mirror of https://github.com/milvus-io/milvus.git
Forbid to get quantized vector from ChunkManager (#24334)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/24381/head
parent
1471da846d
commit
014387fd94
|
@ -212,9 +212,9 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ
|
|||
|
||||
var results []*segcorepb.RetrieveResults
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
|
||||
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs())
|
||||
} else {
|
||||
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
|
||||
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -23,7 +23,6 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
|
@ -32,7 +31,7 @@ import (
|
|||
|
||||
// retrieveOnSegments performs retrieve on listed segments
|
||||
// all segment ids are validated before calling this function
|
||||
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) {
|
||||
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, error) {
|
||||
var (
|
||||
resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs))
|
||||
errs = make([]error, len(segIDs))
|
||||
|
@ -59,7 +58,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
|
|||
errs[i] = err
|
||||
return
|
||||
}
|
||||
if err = segment.FillIndexedFieldsData(ctx, vcm, result); err != nil {
|
||||
if err = segment.ValidateIndexedFieldsData(ctx, result); err != nil {
|
||||
errs[i] = err
|
||||
return
|
||||
}
|
||||
|
@ -87,7 +86,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
|
|||
}
|
||||
|
||||
// retrieveHistorical will retrieve all the target segments in historical
|
||||
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
|
||||
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
|
||||
var err error
|
||||
var retrieveResults []*segcorepb.RetrieveResults
|
||||
var retrieveSegmentIDs []UniqueID
|
||||
|
@ -97,12 +96,12 @@ func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePla
|
|||
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
|
||||
}
|
||||
|
||||
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs, vcm)
|
||||
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs)
|
||||
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
|
||||
}
|
||||
|
||||
// retrieveStreaming will retrieve all the target segments in streaming
|
||||
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
|
||||
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
|
||||
var err error
|
||||
var retrieveResults []*segcorepb.RetrieveResults
|
||||
var retrievePartIDs []UniqueID
|
||||
|
@ -112,6 +111,6 @@ func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan
|
|||
if err != nil {
|
||||
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
|
||||
}
|
||||
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs, vcm)
|
||||
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs)
|
||||
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
|
||||
}
|
||||
|
|
|
@ -126,8 +126,7 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
|
|||
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
|
||||
suite.collectionID,
|
||||
[]int64{suite.partitionID},
|
||||
[]int64{suite.sealed.ID()},
|
||||
nil)
|
||||
[]int64{suite.sealed.ID()})
|
||||
suite.NoError(err)
|
||||
suite.Len(res[0].Offset, 3)
|
||||
}
|
||||
|
@ -139,8 +138,7 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
|
|||
res, _, _, err := RetrieveStreaming(context.TODO(), suite.manager, plan,
|
||||
suite.collectionID,
|
||||
[]int64{suite.partitionID},
|
||||
[]int64{suite.growing.ID()},
|
||||
nil)
|
||||
[]int64{suite.growing.ID()})
|
||||
suite.NoError(err)
|
||||
suite.Len(res[0].Offset, 3)
|
||||
}
|
||||
|
@ -152,8 +150,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
|
|||
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
|
||||
suite.collectionID,
|
||||
[]int64{suite.partitionID},
|
||||
[]int64{999},
|
||||
nil)
|
||||
[]int64{999})
|
||||
suite.NoError(err)
|
||||
suite.Len(res, 0)
|
||||
}
|
||||
|
@ -166,8 +163,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
|
|||
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
|
||||
suite.collectionID,
|
||||
[]int64{suite.partitionID},
|
||||
[]int64{suite.sealed.ID()},
|
||||
nil)
|
||||
[]int64{suite.sealed.ID()})
|
||||
suite.ErrorIs(err, ErrSegmentReleased)
|
||||
suite.Len(res, 0)
|
||||
}
|
||||
|
|
|
@ -28,6 +28,8 @@ import "C"
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"sort"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
@ -46,7 +48,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
|
@ -473,10 +474,7 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (
|
|||
return dataPath, offsetInBinlog
|
||||
}
|
||||
|
||||
func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
|
||||
vcm storage.ChunkManager,
|
||||
result *segcorepb.RetrieveResults,
|
||||
) error {
|
||||
func (s *LocalSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", s.Collection()),
|
||||
zap.Int64("partitionID", s.Partition()),
|
||||
|
@ -484,43 +482,21 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
|
|||
)
|
||||
|
||||
for _, fieldData := range result.FieldsData {
|
||||
// If the field is not vector field, no need to download data from remote.
|
||||
if !typeutil.IsVectorType(fieldData.GetType()) {
|
||||
continue
|
||||
}
|
||||
// If the vector field doesn't have indexed, vector data is in memory
|
||||
// for brute force search, no need to download data from remote.
|
||||
if !s.ExistIndex(fieldData.FieldId) {
|
||||
continue
|
||||
}
|
||||
// If the index has raw data, vector data could be obtained from index,
|
||||
// no need to download data from remote.
|
||||
if s.HasRawData(fieldData.FieldId) {
|
||||
continue
|
||||
}
|
||||
|
||||
index := s.GetIndex(fieldData.FieldId)
|
||||
if index == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: optimize here. Now we'll read a whole file from storage every time we retrieve raw data by offset.
|
||||
for i, offset := range result.Offset {
|
||||
dataPath, dataOffset := s.GetFieldDataPath(index, offset)
|
||||
endian := common.Endian
|
||||
|
||||
// fill field data that fieldData[i] = dataPath[offsetInBinlog*rowBytes, (offsetInBinlog+1)*rowBytes]
|
||||
if err := fillFieldData(ctx, vcm, dataPath, fieldData, i, dataOffset, endian); err != nil {
|
||||
log.Warn("failed to fill field data",
|
||||
zap.Int64("offset", offset),
|
||||
zap.String("dataPath", dataPath),
|
||||
zap.Int64("dataOffset", dataOffset),
|
||||
zap.Int64("fieldID", fieldData.GetFieldId()),
|
||||
zap.String("fieldType", fieldData.GetType().String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
if !s.HasRawData(fieldData.FieldId) {
|
||||
index := s.GetIndex(fieldData.FieldId)
|
||||
indexType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, index.IndexInfo.GetIndexParams())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = fmt.Errorf("output fields for %s index is not allowed", indexType)
|
||||
log.Warn("validate fields failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
storage "github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
|
@ -127,6 +130,56 @@ func (suite *SegmentSuite) TestHasRawData() {
|
|||
suite.True(has)
|
||||
}
|
||||
|
||||
func (suite *SegmentSuite) TestValidateIndexedFieldsData() {
|
||||
result := &segcorepb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{5, 4, 3, 2, 9, 8, 7, 6},
|
||||
}},
|
||||
},
|
||||
Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
genFieldData("int64 field", 100, schemapb.DataType_Int64,
|
||||
[]int64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
|
||||
genFieldData("float vector field", 101, schemapb.DataType_FloatVector,
|
||||
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
|
||||
},
|
||||
}
|
||||
|
||||
// no index
|
||||
err := suite.growing.ValidateIndexedFieldsData(context.Background(), result)
|
||||
suite.NoError(err)
|
||||
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
|
||||
suite.NoError(err)
|
||||
|
||||
// with index and has raw data
|
||||
suite.sealed.AddIndex(101, &IndexedFieldInfo{
|
||||
IndexInfo: &querypb.FieldIndexInfo{
|
||||
FieldID: 101,
|
||||
EnableIndex: true,
|
||||
},
|
||||
})
|
||||
suite.True(suite.sealed.ExistIndex(101))
|
||||
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
|
||||
suite.NoError(err)
|
||||
|
||||
// index doesn't have index type
|
||||
DeleteSegment(suite.sealed)
|
||||
suite.True(suite.sealed.ExistIndex(101))
|
||||
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
|
||||
suite.Error(err)
|
||||
|
||||
// with index but doesn't have raw data
|
||||
index := suite.sealed.GetIndex(101)
|
||||
_, indexParams := genIndexParams(IndexHNSW, L2)
|
||||
index.IndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams)
|
||||
DeleteSegment(suite.sealed)
|
||||
suite.True(suite.sealed.ExistIndex(101))
|
||||
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
|
||||
suite.Error(err)
|
||||
}
|
||||
|
||||
func TestSegment(t *testing.T) {
|
||||
suite.Run(t, new(SegmentSuite))
|
||||
}
|
||||
|
|
|
@ -45,6 +45,9 @@ type TestGetVectorSuite struct {
|
|||
metricType string
|
||||
pkType schemapb.DataType
|
||||
vecType schemapb.DataType
|
||||
|
||||
// expected
|
||||
searchFailed bool
|
||||
}
|
||||
|
||||
func (s *TestGetVectorSuite) run() {
|
||||
|
@ -172,6 +175,11 @@ func (s *TestGetVectorSuite) run() {
|
|||
|
||||
searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq)
|
||||
s.Require().NoError(err)
|
||||
if s.searchFailed {
|
||||
s.Require().NotEqual(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
s.T().Logf("reason:%s", searchResp.GetStatus().GetReason())
|
||||
return
|
||||
}
|
||||
s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
||||
result := searchResp.GetResults()
|
||||
|
@ -253,6 +261,7 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -263,6 +272,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -273,6 +283,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = true
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -283,6 +294,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = true
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -293,6 +305,7 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -303,6 +316,7 @@ func (s *TestGetVectorSuite) TestGetVector_IP() {
|
|||
s.metricType = distance.IP
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -313,6 +327,7 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_VarChar
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -323,6 +338,7 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() {
|
|||
s.metricType = distance.JACCARD
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_BinaryVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -334,6 +350,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
|
|||
s.metricType = distance.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
s.run()
|
||||
}
|
||||
|
||||
|
@ -344,6 +361,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
|
|||
// s.metricType = distance.L2
|
||||
// s.pkType = schemapb.DataType_Int64
|
||||
// s.vecType = schemapb.DataType_FloatVector
|
||||
// s.searchFailed = false
|
||||
// s.run()
|
||||
//}
|
||||
|
||||
|
|
|
@ -1418,9 +1418,10 @@ class TestQueryOperation(TestcaseBase):
|
|||
assert collection_w.has_index()[0]
|
||||
res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
|
||||
collection_w.load()
|
||||
error = {ct.err_code: 1, ct.err_msg: 'not allowed'}
|
||||
collection_w.query(default_term_expr, output_fields=fields,
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={exp_res: res, "with_vec": True})
|
||||
check_task=CheckTasks.err_res,
|
||||
check_items=error)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_query_output_binary_vec_field_after_index(self):
|
||||
|
|
Loading…
Reference in New Issue