Forbid getting quantized vectors from ChunkManager (#24334)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
Author: yihao.dai, 2023-05-24 23:03:27 +08:00 (committed by GitHub)
Parent: 1471da846d
Commit: 014387fd94
7 changed files with 96 additions and 53 deletions
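
Not part of the diff: a minimal standalone sketch of the rule this commit enforces, with a made-up helper name (the real check is LocalSegment.ValidateIndexedFieldsData in the segment.go hunks below). A vector field may be returned as an output field only if the segment has no index on it or the index itself keeps the raw vectors; the old fallback that re-read vectors through the ChunkManager is removed, so such requests now fail fast.

package main

import "fmt"

// canOutputVectorField is a hypothetical helper mirroring the new policy:
// reject vector output fields whose index does not retain raw data.
func canOutputVectorField(hasIndex, indexHasRawData bool, indexType string) error {
	if hasIndex && !indexHasRawData {
		// Same wording as the error introduced by this commit.
		return fmt.Errorf("output fields for %s index is not allowed", indexType)
	}
	return nil
}

func main() {
	// A quantized index such as IVF_PQ does not keep raw vectors: rejected.
	fmt.Println(canOutputVectorField(true, false, "IVF_PQ"))
	// An index that keeps raw data (e.g. HNSW) still serves vector output fields.
	fmt.Println(canOutputVectorField(true, true, "HNSW"))
}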

@@ -212,9 +212,9 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ
var results []*segcorepb.RetrieveResults
if req.GetScope() == querypb.DataScope_Historical {
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs())
} else {
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs())
}
if err != nil {
return nil, err

@@ -23,7 +23,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
@@ -32,7 +31,7 @@ import (
// retrieveOnSegments performs retrieve on listed segments
// all segment ids are validated before calling this function
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) {
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, error) {
var (
resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs))
errs = make([]error, len(segIDs))
@@ -59,7 +58,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
errs[i] = err
return
}
if err = segment.FillIndexedFieldsData(ctx, vcm, result); err != nil {
if err = segment.ValidateIndexedFieldsData(ctx, result); err != nil {
errs[i] = err
return
}
@@ -87,7 +86,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
}
// retrieveHistorical will retrieve all the target segments in historical
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error
var retrieveResults []*segcorepb.RetrieveResults
var retrieveSegmentIDs []UniqueID
@@ -97,12 +96,12 @@ func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePla
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
}
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs, vcm)
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs)
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
}
// retrieveStreaming will retrieve all the target segments in streaming
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error
var retrieveResults []*segcorepb.RetrieveResults
var retrievePartIDs []UniqueID
@@ -112,6 +111,6 @@ func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan
if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
}
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs, vcm)
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs)
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
}
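
Not part of the diff: a hedged sketch of the new call shape, assuming it sits inside internal/querynodev2 where these internal packages are importable (queryHistorical is a made-up name, and the plain int64 IDs rely on UniqueID being an int64 alias).

package querynodev2

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/segcorepb"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
)

// queryHistorical shows what callers pass after this change: no
// storage.ChunkManager argument, so retrieval never touches object storage
// to rebuild vector output fields.
func queryHistorical(ctx context.Context, mgr *segments.Manager, plan *segments.RetrievePlan,
	collectionID int64, segmentIDs []int64) ([]*segcorepb.RetrieveResults, error) {
	results, _, _, err := segments.RetrieveHistorical(ctx, mgr, plan, collectionID, nil, segmentIDs)
	return results, err
}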

@@ -126,8 +126,7 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.sealed.ID()},
nil)
[]int64{suite.sealed.ID()})
suite.NoError(err)
suite.Len(res[0].Offset, 3)
}
@@ -139,8 +138,7 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
res, _, _, err := RetrieveStreaming(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.growing.ID()},
nil)
[]int64{suite.growing.ID()})
suite.NoError(err)
suite.Len(res[0].Offset, 3)
}
@@ -152,8 +150,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{999},
nil)
[]int64{999})
suite.NoError(err)
suite.Len(res, 0)
}
@@ -166,8 +163,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.sealed.ID()},
nil)
[]int64{suite.sealed.ID()})
suite.ErrorIs(err, ErrSegmentReleased)
suite.Len(res, 0)
}

@@ -28,6 +28,8 @@ import "C"
import (
"context"
"fmt"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"sort"
"sync"
"unsafe"
@@ -46,7 +48,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/segcorepb"
pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@@ -473,10 +474,7 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (
return dataPath, offsetInBinlog
}
func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
vcm storage.ChunkManager,
result *segcorepb.RetrieveResults,
) error {
func (s *LocalSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
@@ -484,43 +482,21 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
)
for _, fieldData := range result.FieldsData {
// If the field is not vector field, no need to download data from remote.
if !typeutil.IsVectorType(fieldData.GetType()) {
continue
}
// If the vector field doesn't have indexed, vector data is in memory
// for brute force search, no need to download data from remote.
if !s.ExistIndex(fieldData.FieldId) {
continue
}
// If the index has raw data, vector data could be obtained from index,
// no need to download data from remote.
if s.HasRawData(fieldData.FieldId) {
continue
}
index := s.GetIndex(fieldData.FieldId)
if index == nil {
continue
}
// TODO: optimize here. Now we'll read a whole file from storage every time we retrieve raw data by offset.
for i, offset := range result.Offset {
dataPath, dataOffset := s.GetFieldDataPath(index, offset)
endian := common.Endian
// fill field data that fieldData[i] = dataPath[offsetInBinlog*rowBytes, (offsetInBinlog+1)*rowBytes]
if err := fillFieldData(ctx, vcm, dataPath, fieldData, i, dataOffset, endian); err != nil {
log.Warn("failed to fill field data",
zap.Int64("offset", offset),
zap.String("dataPath", dataPath),
zap.Int64("dataOffset", dataOffset),
zap.Int64("fieldID", fieldData.GetFieldId()),
zap.String("fieldType", fieldData.GetType().String()),
zap.Error(err),
)
if !s.HasRawData(fieldData.FieldId) {
index := s.GetIndex(fieldData.FieldId)
indexType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, index.IndexInfo.GetIndexParams())
if err != nil {
return err
}
err = fmt.Errorf("output fields for %s index is not allowed", indexType)
log.Warn("validate fields failed", zap.Error(err))
return err
}
}
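
Not part of the diff: a small, hedged illustration of the lookup used by the new check above. IndexParams is a repeated commonpb.KeyValuePair, and common.IndexTypeKey selects the entry naming the index kind; if that entry is missing, funcutil.GetAttrByKeyFromRepeatedKV returns an error, which ValidateIndexedFieldsData propagates as-is (import paths as used elsewhere in this commit; the parameter values below are made up).

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
)

func main() {
	// Index params in the shape stored on a FieldIndexInfo.
	params := []*commonpb.KeyValuePair{
		{Key: common.IndexTypeKey, Value: "IVF_PQ"},
		{Key: "metric_type", Value: "L2"},
	}
	indexType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, params)
	fmt.Println(indexType, err) // "IVF_PQ" <nil>

	// With the index_type entry absent the lookup fails; this is the
	// "index doesn't have index type" case exercised by the test below.
	_, err = funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, nil)
	fmt.Println(err != nil) // true
}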

@@ -1,13 +1,16 @@
package segments
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
storage "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@@ -127,6 +130,56 @@ func (suite *SegmentSuite) TestHasRawData() {
suite.True(has)
}
func (suite *SegmentSuite) TestValidateIndexedFieldsData() {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{5, 4, 3, 2, 9, 8, 7, 6},
}},
},
Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6},
FieldsData: []*schemapb.FieldData{
genFieldData("int64 field", 100, schemapb.DataType_Int64,
[]int64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("float vector field", 101, schemapb.DataType_FloatVector,
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
},
}
// no index
err := suite.growing.ValidateIndexedFieldsData(context.Background(), result)
suite.NoError(err)
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.NoError(err)
// with index and has raw data
suite.sealed.AddIndex(101, &IndexedFieldInfo{
IndexInfo: &querypb.FieldIndexInfo{
FieldID: 101,
EnableIndex: true,
},
})
suite.True(suite.sealed.ExistIndex(101))
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.NoError(err)
// index doesn't have index type
DeleteSegment(suite.sealed)
suite.True(suite.sealed.ExistIndex(101))
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.Error(err)
// with index but doesn't have raw data
index := suite.sealed.GetIndex(101)
_, indexParams := genIndexParams(IndexHNSW, L2)
index.IndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams)
DeleteSegment(suite.sealed)
suite.True(suite.sealed.ExistIndex(101))
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.Error(err)
}
func TestSegment(t *testing.T) {
suite.Run(t, new(SegmentSuite))
}

@@ -45,6 +45,9 @@ type TestGetVectorSuite struct {
metricType string
pkType schemapb.DataType
vecType schemapb.DataType
// expected
searchFailed bool
}
func (s *TestGetVectorSuite) run() {
@@ -172,6 +175,11 @@ func (s *TestGetVectorSuite) run() {
searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq)
s.Require().NoError(err)
if s.searchFailed {
s.Require().NotEqual(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
s.T().Logf("reason:%s", searchResp.GetStatus().GetReason())
return
}
s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
result := searchResp.GetResults()
@@ -253,6 +261,7 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run()
}
@@ -263,6 +272,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run()
}
@@ -273,6 +283,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = true
s.run()
}
@@ -283,6 +294,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = true
s.run()
}
@@ -293,6 +305,7 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run()
}
@@ -303,6 +316,7 @@ func (s *TestGetVectorSuite) TestGetVector_IP() {
s.metricType = distance.IP
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run()
}
@@ -313,6 +327,7 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_VarChar
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run()
}
@@ -323,6 +338,7 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() {
s.metricType = distance.JACCARD
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_BinaryVector
s.searchFailed = false
s.run()
}
@@ -334,6 +350,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run()
}
@@ -344,6 +361,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
// s.metricType = distance.L2
// s.pkType = schemapb.DataType_Int64
// s.vecType = schemapb.DataType_FloatVector
// s.searchFailed = false
// s.run()
//}
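
Not part of the diff: the expectations encoded by the searchFailed flags above, restated as data. The quantized index types do not retain raw vectors, so asking for a vector field in output_fields now fails; index types that keep raw data continue to work. Treat this as a reading aid for the tests, not an exhaustive matrix.

package gettest

// vectorOutputAllowed summarizes the searchFailed flags set in the
// integration tests above (searchFailed = false means vector output still works).
var vectorOutputAllowed = map[string]bool{
	"FLAT":     true,
	"IVF_FLAT": true,
	"IVF_PQ":   false, // quantized; the index does not keep raw vectors
	"IVF_SQ8":  false, // quantized; the index does not keep raw vectors
	"HNSW":     true,
}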

@@ -1418,9 +1418,10 @@ class TestQueryOperation(TestcaseBase):
assert collection_w.has_index()[0]
res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
collection_w.load()
error = {ct.err_code: 1, ct.err_msg: 'not allowed'}
collection_w.query(default_term_expr, output_fields=fields,
check_task=CheckTasks.check_query_results,
check_items={exp_res: res, "with_vec": True})
check_task=CheckTasks.err_res,
check_items=error)
@pytest.mark.tags(CaseLabel.L1)
def test_query_output_binary_vec_field_after_index(self):