update fillRetrieveResults parameter (#6598)

* update fillRetrieveResults parameter

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* optimize fillVectorFieldData process

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/6600/head
Cai Yudong 2021-07-17 15:17:30 +08:00 committed by GitHub
parent 2d9361e0ba
commit a7b27db63e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 105 deletions

View File

@ -31,7 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/tsoutil"
@ -392,38 +392,6 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) {
sp.Finish()
}
func (q *queryCollection) getVectorOutputFieldIDs(msg queryMsg) ([]int64, error) {
var collID UniqueID
var outputFieldsID []int64
var resultFieldIDs []int64
msgType := msg.Type()
switch msgType {
case commonpb.MsgType_Retrieve:
retrieveMsg := msg.(*msgstream.RetrieveMsg)
collID = retrieveMsg.CollectionID
outputFieldsID = retrieveMsg.OutputFieldsId
case commonpb.MsgType_Search:
searchMsg := msg.(*msgstream.SearchMsg)
collID = searchMsg.CollectionID
outputFieldsID = searchMsg.OutputFieldsId
default:
return resultFieldIDs, fmt.Errorf("receive invalid msgType = %d", msgType)
}
vecFields, err := q.historical.replica.getVecFieldIDsByCollectionID(collID)
if err != nil {
return resultFieldIDs, err
}
for _, fieldID := range vecFields {
if funcutil.SliceContain(outputFieldsID, fieldID) {
resultFieldIDs = append(resultFieldIDs, fieldID)
}
}
return resultFieldIDs, nil
}
func (q *queryCollection) doUnsolvedQueryMsg() {
log.Debug("starting doUnsolvedMsg...", zap.Any("collectionID", q.collectionID))
for {
@ -787,7 +755,7 @@ func translateHits(schema *typeutil.SchemaHelper, fieldIDs []int64, rawHits [][]
finalResult.FieldsData = append(finalResult.FieldsData, newCol)
blobOffset += blobLen
default:
return nil, fmt.Errorf("unsupport data type %s", schemapb.DataType_name[int32(fieldMeta.DataType)])
return nil, fmt.Errorf("unsupported data type %s", schemapb.DataType_name[int32(fieldMeta.DataType)])
}
}
@ -1090,32 +1058,55 @@ func (q *queryCollection) search(msg queryMsg) error {
return nil
}
func (q *queryCollection) fillVectorOutputFieldsIfNeeded(msg queryMsg, segment *Segment, result *segcorepb.RetrieveResults) error {
// result is not empty
if len(result.Offset) <= 0 {
return nil
}
// get all vector output field ids
vecOutputFieldIDs, err := q.getVectorOutputFieldIDs(msg)
if err != nil {
return err
}
// output_fields contain vector field
for _, vecOutputFieldID := range vecOutputFieldIDs {
log.Debug("CYD - ", zap.Int64("vecOutputFieldID", vecOutputFieldID))
vecFieldInfo, err := segment.getVectorFieldInfo(vecOutputFieldID)
func (q *queryCollection) fillVectorFieldsData(segment *Segment, result *segcorepb.RetrieveResults) error {
for _, resultFieldData := range result.FieldsData {
vecFieldInfo, err := segment.getVectorFieldInfo(resultFieldData.FieldId)
if err != nil {
return fmt.Errorf("cannot get vector field info, fileID %d", vecOutputFieldID)
continue
}
// vector field raw data is not loaded into memory
if !vecFieldInfo.getRawDataInMemory() {
if err = q.historical.loader.loadSegmentVectorFieldsData(vecFieldInfo); err != nil {
return err
}
if err = segment.fillRetrieveResults(result, vecOutputFieldID, vecFieldInfo); err != nil {
return err
// if vector raw data is in memory, result should has been filled in valid vector raw data
if vecFieldInfo.getRawDataInMemory() {
continue
}
// load vector field data
if err = q.historical.loader.loadSegmentVectorFieldData(vecFieldInfo); err != nil {
return err
}
for i, offset := range result.Offset {
var success bool
for _, path := range vecFieldInfo.fieldBinlog.Binlogs {
rawData := vecFieldInfo.getRawData(path)
var numRows, dim int64
switch fieldData := rawData.(type) {
case *storage.FloatVectorFieldData:
numRows = int64(fieldData.NumRows)
dim = int64(fieldData.Dim)
if offset < numRows {
copy(resultFieldData.GetVectors().GetFloatVector().Data[int64(i)*dim:int64(i+1)*dim], fieldData.Data[offset*dim:(offset+1)*dim])
success = true
} else {
offset -= numRows
}
case *storage.BinaryVectorFieldData:
numRows = int64(fieldData.NumRows)
dim = int64(fieldData.Dim)
if offset < numRows {
x := resultFieldData.GetVectors().GetData().(*schemapb.VectorField_BinaryVector)
copy(x.BinaryVector[int64(i)*dim/8:int64(i+1)*dim/8], fieldData.Data[offset*dim/8:(offset+1)*dim/8])
success = true
} else {
offset -= numRows
}
default:
return fmt.Errorf("unexpected field data type")
}
if success {
break
}
}
}
}
@ -1200,7 +1191,7 @@ func (q *queryCollection) retrieve(msg queryMsg) error {
return err
}
if err = q.fillVectorOutputFieldsIfNeeded(msg, segment, result); err != nil {
if err = q.fillVectorFieldsData(segment, result); err != nil {
return err
}
mergeList = append(mergeList, result)

View File

@ -35,7 +35,6 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
)
@ -322,50 +321,6 @@ func (s *Segment) getEntityByIds(plan *RetrievePlan) (*segcorepb.RetrieveResults
return result, nil
}
func (s *Segment) fillRetrieveResults(result *segcorepb.RetrieveResults, fieldID int64, fieldInfo *VectorFieldInfo) error {
for _, resultFieldData := range result.FieldsData {
if resultFieldData.FieldId != fieldID {
continue
}
for i, offset := range result.Offset {
var success bool
for _, path := range fieldInfo.fieldBinlog.Binlogs {
rawData := fieldInfo.getRawData(path)
var numRows, dim int64
switch fieldData := rawData.(type) {
case *storage.FloatVectorFieldData:
numRows = int64(fieldData.NumRows)
dim = int64(fieldData.Dim)
if offset < numRows {
copy(resultFieldData.GetVectors().GetFloatVector().Data[int64(i)*dim:int64(i+1)*dim], fieldData.Data[offset*dim:(offset+1)*dim])
success = true
} else {
offset -= numRows
}
case *storage.BinaryVectorFieldData:
numRows = int64(fieldData.NumRows)
dim = int64(fieldData.Dim)
if offset < numRows {
x := resultFieldData.GetVectors().GetData().(*schemapb.VectorField_BinaryVector)
copy(x.BinaryVector[int64(i)*dim/8:int64(i+1)*dim/8], fieldData.Data[offset*dim/8:(offset+1)*dim/8])
success = true
} else {
offset -= numRows
}
default:
return errors.New("unexpected field data type")
}
if success {
break
}
}
}
}
return nil
}
func (s *Segment) fillTargetEntry(plan *SearchPlan, result *SearchResult) error {
if s.segmentPtr == nil {
return errors.New("null seg core pointer")

View File

@ -301,7 +301,7 @@ func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlog
return nil
}
func (loader *segmentLoader) loadSegmentVectorFieldsData(info *VectorFieldInfo) error {
func (loader *segmentLoader) loadSegmentVectorFieldData(info *VectorFieldInfo) error {
iCodec := storage.InsertCodec{}
defer func() {
err := iCodec.Close()
@ -314,6 +314,8 @@ func (loader *segmentLoader) loadSegmentVectorFieldsData(info *VectorFieldInfo)
continue
}
log.Debug("load vector raw data", zap.String("path", path))
binLog, err := loader.minioKV.Load(path)
if err != nil {
return err