Fix empty retrieve result (#21032) (#21066)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/21111/head
Jiquan Long 2022-12-09 16:37:20 +08:00 committed by GitHub
parent 714782fce1
commit 41f7b10abb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 741 additions and 7 deletions

View File

@ -87,3 +87,7 @@ const (
const (
CollectionTTLConfigKey = "collection.ttl.seconds"
)
func IsSystemField(fieldID int64) bool {
return fieldID < StartOfUserFieldID
}

View File

@ -0,0 +1,40 @@
package common
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsSystemField(t *testing.T) {
type args struct {
fieldID int64
}
tests := []struct {
name string
args args
want bool
}{
{
args: args{fieldID: StartOfUserFieldID},
want: false,
},
{
args: args{fieldID: StartOfUserFieldID + 1},
want: false,
},
{
args: args{fieldID: TimeStampField},
want: true,
},
{
args: args{fieldID: RowIDField},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, IsSystemField(tt.args.fieldID), "IsSystemField(%v)", tt.args.fieldID)
})
}
}

View File

@ -50,6 +50,7 @@ type queryTask struct {
ids *schemapb.IDs
collectionName string
queryParams *queryParams
schema *schemapb.CollectionSchema
resultBuf chan *internalpb.RetrieveResults
toReduceResults []*internalpb.RetrieveResults
@ -110,6 +111,16 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio
return outputFieldIDs, nil
}
func filterSystemFields(outputFieldIDs []UniqueID) []UniqueID {
filtered := make([]UniqueID, 0, len(outputFieldIDs))
for _, outputFieldID := range outputFieldIDs {
if !common.IsSystemField(outputFieldID) {
filtered = append(filtered, outputFieldID)
}
}
return filtered
}
// parseQueryParams get limit and offset from queryParamsPair, both are optional.
func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, error) {
var (
@ -225,6 +236,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
}
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName)
t.schema = schema
if t.ids != nil {
pkField := ""
@ -348,7 +360,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
tr.CtxRecord(ctx, "reduceResultStart")
t.result, err = reduceRetrieveResults(ctx, t.toReduceResults, t.queryParams)
t.result, err = reduceRetrieveResultsAndFillIfEmpty(ctx, t.toReduceResults, t.queryParams, t.GetOutputFieldsId(), t.schema)
if err != nil {
return err
}
@ -493,6 +505,21 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
return ret, nil
}
func reduceRetrieveResultsAndFillIfEmpty(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams, outputFieldsID []int64, schema *schemapb.CollectionSchema) (*milvuspb.QueryResults, error) {
result, err := reduceRetrieveResults(ctx, retrieveResults, queryParams)
if err != nil {
return nil, err
}
// filter system fields.
filtered := filterSystemFields(outputFieldsID)
if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewMilvusResult(result), filtered, schema); err != nil {
return nil, fmt.Errorf("failed to fill retrieve results: %s", err.Error())
}
return result, nil
}
func (t *queryTask) TraceCtx() context.Context {
return t.ctx
}

View File

@ -651,3 +651,9 @@ func getFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType,
return fieldData
}
func Test_filterSystemFields(t *testing.T) {
outputFieldIDs := []UniqueID{common.RowIDField, common.TimeStampField, common.StartOfUserFieldID}
filtered := filterSystemFields(outputFieldIDs)
assert.ElementsMatch(t, []UniqueID{common.StartOfUserFieldID}, filtered)
}

View File

@ -1051,7 +1051,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
results = append(results, streamingResult)
ret, err2 := mergeInternalRetrieveResult(ctx, results, req.Req.GetLimit())
ret, err2 := mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), qs.collection.Schema())
if err2 != nil {
failRet.Status.Reason = err2.Error()
return failRet, nil
@ -1093,6 +1093,13 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
},
}
coll, err := node.metaReplica.getCollectionByID(req.GetReq().GetCollectionID())
if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error()
return failRet, nil
}
toMergeResults := make([]*internalpb.RetrieveResults, 0)
runningGp, runningCtx := errgroup.WithContext(ctx)
mu := &sync.Mutex{}
@ -1127,7 +1134,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if err := runningGp.Wait(); err != nil {
return failRet, nil
}
ret, err := mergeInternalRetrieveResult(ctx, toMergeResults, req.GetReq().GetLimit())
ret, err := mergeInternalRetrieveResultsAndFillIfEmpty(ctx, toMergeResults, req.GetReq().GetLimit(), req.GetReq().GetOutputFieldsId(), coll.Schema())
if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error()

View File

@ -372,6 +372,46 @@ func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
return ret, nil
}
func mergeSegcoreRetrieveResultsAndFillIfEmpty(
ctx context.Context,
retrieveResults []*segcorepb.RetrieveResults,
limit int64,
outputFieldsID []int64,
schema *schemapb.CollectionSchema,
) (*segcorepb.RetrieveResults, error) {
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, limit)
if err != nil {
return nil, err
}
if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewSegcoreResults(mergedResult), outputFieldsID, schema); err != nil {
return nil, fmt.Errorf("failed to fill segcore retrieve results: %s", err.Error())
}
return mergedResult, nil
}
func mergeInternalRetrieveResultsAndFillIfEmpty(
ctx context.Context,
retrieveResults []*internalpb.RetrieveResults,
limit int64,
outputFieldsID []int64,
schema *schemapb.CollectionSchema,
) (*internalpb.RetrieveResults, error) {
mergedResult, err := mergeInternalRetrieveResult(ctx, retrieveResults, limit)
if err != nil {
return nil, err
}
if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewInternalResult(mergedResult), outputFieldsID, schema); err != nil {
return nil, fmt.Errorf("failed to fill internal retrieve results: %s", err.Error())
}
return mergedResult, nil
}
// func printSearchResultData(data *schemapb.SearchResultData, header string) {
// size := len(data.Ids.GetIntId().Data)
// if size != len(data.Scores) {

View File

@ -61,7 +61,7 @@ func (q *queryTask) queryOnStreaming() error {
}
// check if collection has been released, check streaming since it's released first
_, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
coll, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
if err != nil {
return err
}
@ -85,7 +85,7 @@ func (q *queryTask) queryOnStreaming() error {
}
q.tr.RecordSpan()
mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults, q.iReq.GetLimit())
mergedResult, err := mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, sResults, q.iReq.GetLimit(), q.iReq.GetOutputFieldsId(), coll.Schema())
if err != nil {
return err
}
@ -107,7 +107,7 @@ func (q *queryTask) queryOnHistorical() error {
}
// check if collection has been released, check historical since it's released first
_, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
coll, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
if err != nil {
return err
}
@ -129,10 +129,11 @@ func (q *queryTask) queryOnHistorical() error {
return err
}
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, q.req.GetReq().GetLimit())
mergedResult, err := mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, retrieveResults, q.req.GetReq().GetLimit(), q.iReq.GetOutputFieldsId(), coll.Schema())
if err != nil {
return err
}
q.Ret = &internalpb.RetrieveResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Ids: mergedResult.Ids,

View File

@ -0,0 +1,163 @@
package typeutil
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
func fieldDataEmpty(data *schemapb.FieldData) bool {
if data == nil {
return true
}
switch realData := data.Field.(type) {
case *schemapb.FieldData_Scalars:
switch realScalars := realData.Scalars.Data.(type) {
case *schemapb.ScalarField_BoolData:
return len(realScalars.BoolData.GetData()) <= 0
case *schemapb.ScalarField_LongData:
return len(realScalars.LongData.GetData()) <= 0
case *schemapb.ScalarField_FloatData:
return len(realScalars.FloatData.GetData()) <= 0
case *schemapb.ScalarField_DoubleData:
return len(realScalars.DoubleData.GetData()) <= 0
case *schemapb.ScalarField_StringData:
return len(realScalars.StringData.GetData()) <= 0
}
case *schemapb.FieldData_Vectors:
switch realVectors := realData.Vectors.Data.(type) {
case *schemapb.VectorField_BinaryVector:
return len(realVectors.BinaryVector) <= 0
case *schemapb.VectorField_FloatVector:
return len(realVectors.FloatVector.Data) <= 0
}
}
return true
}
func genEmptyBoolFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: nil}},
},
},
FieldId: field.GetFieldID(),
}
}
func genEmptyIntFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: nil}},
},
},
FieldId: field.GetFieldID(),
}
}
func genEmptyFloatFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{FloatData: &schemapb.FloatArray{Data: nil}},
},
},
FieldId: field.GetFieldID(),
}
}
func genEmptyDoubleFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{DoubleData: &schemapb.DoubleArray{Data: nil}},
},
},
FieldId: field.GetFieldID(),
}
}
func genEmptyVarCharFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: nil}},
},
},
FieldId: field.GetFieldID(),
}
}
func genEmptyBinaryVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) {
dim, err := GetDim(field)
if err != nil {
return nil, err
}
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: nil,
},
},
},
FieldId: field.GetFieldID(),
}, nil
}
func genEmptyFloatVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) {
dim, err := GetDim(field)
if err != nil {
return nil, err
}
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{Data: nil},
},
},
},
FieldId: field.GetFieldID(),
}, nil
}
func genEmptyFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) {
dataType := field.GetDataType()
switch dataType {
case schemapb.DataType_Bool:
return genEmptyBoolFieldData(field), nil
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64:
return genEmptyIntFieldData(field), nil
case schemapb.DataType_Float:
return genEmptyFloatFieldData(field), nil
case schemapb.DataType_Double:
return genEmptyDoubleFieldData(field), nil
case schemapb.DataType_VarChar:
return genEmptyVarCharFieldData(field), nil
case schemapb.DataType_BinaryVector:
return genEmptyBinaryVectorFieldData(field)
case schemapb.DataType_FloatVector:
return genEmptyFloatVectorFieldData(field)
default:
return nil, fmt.Errorf("unsupported data type: %s", dataType.String())
}
}

View File

@ -0,0 +1,25 @@
package typeutil
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
// GetDim get dimension of field. Maybe also helpful outside.
func GetDim(field *schemapb.FieldSchema) (int64, error) {
if !IsVectorType(field.GetDataType()) {
return 0, fmt.Errorf("%s is not of vector type", field.GetDataType())
}
h := NewKvPairs(append(field.GetIndexParams(), field.GetTypeParams()...))
dimStr, err := h.Get("dim")
if err != nil {
return 0, fmt.Errorf("dim not found")
}
dim, err := strconv.Atoi(dimStr)
if err != nil {
return 0, fmt.Errorf("invalid dimension: %s", dimStr)
}
return int64(dim), nil
}

View File

@ -0,0 +1,31 @@
package typeutil
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
)
type kvPairsHelper[K comparable, V any] struct {
kvPairs map[K]V
}
func (h *kvPairsHelper[K, V]) Get(k K) (V, error) {
v, ok := h.kvPairs[k]
if !ok {
return v, fmt.Errorf("%v not found", k)
}
return v, nil
}
func NewKvPairs(pairs []*commonpb.KeyValuePair) *kvPairsHelper[string, string] {
helper := &kvPairsHelper[string, string]{
kvPairs: make(map[string]string),
}
for _, pair := range pairs {
helper.kvPairs[pair.GetKey()] = pair.GetValue()
}
return helper
}

View File

@ -0,0 +1,20 @@
package typeutil
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/stretchr/testify/assert"
)
func TestNewKvPairs(t *testing.T) {
kvPairs := []*commonpb.KeyValuePair{
{Key: "dim", Value: "128"},
}
h := NewKvPairs(kvPairs)
v, err := h.Get("dim")
assert.NoError(t, err)
assert.Equal(t, "128", v)
_, err = h.Get("not_exist")
assert.Error(t, err)
}

View File

@ -0,0 +1,41 @@
package typeutil
import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
func preHandleEmptyResult(result RetrieveResults) {
result.PreHandle()
}
func appendFieldData(result RetrieveResults, fieldData *schemapb.FieldData) {
result.AppendFieldData(fieldData)
}
func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, schema *schemapb.CollectionSchema) error {
if !result.ResultEmpty() {
return nil
}
preHandleEmptyResult(result)
helper, err := CreateSchemaHelper(schema)
if err != nil {
return err
}
for _, outputFieldID := range outputFieldIds {
field, err := helper.GetFieldFromID(outputFieldID)
if err != nil {
return err
}
emptyFieldData, err := genEmptyFieldData(field)
if err != nil {
return err
}
appendFieldData(result, emptyFieldData)
}
return nil
}

View File

@ -0,0 +1,253 @@
package typeutil
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
)
func TestGenEmptyFieldData(t *testing.T) {
allTypes := []schemapb.DataType{
schemapb.DataType_Bool,
schemapb.DataType_Int8,
schemapb.DataType_Int16,
schemapb.DataType_Int32,
schemapb.DataType_Int64,
schemapb.DataType_Float,
schemapb.DataType_Double,
schemapb.DataType_VarChar,
}
allUnsupportedTypes := []schemapb.DataType{
schemapb.DataType_String,
schemapb.DataType_None,
}
vectorTypes := []schemapb.DataType{
schemapb.DataType_BinaryVector,
schemapb.DataType_FloatVector,
}
field := &schemapb.FieldSchema{Name: "field_name", FieldID: 100}
for _, dataType := range allTypes {
field.DataType = dataType
fieldData, err := genEmptyFieldData(field)
assert.NoError(t, err)
assert.Equal(t, dataType, fieldData.GetType())
assert.Equal(t, field.GetName(), fieldData.GetFieldName())
assert.True(t, fieldDataEmpty(fieldData))
assert.Equal(t, field.GetFieldID(), fieldData.GetFieldId())
}
for _, dataType := range allUnsupportedTypes {
field.DataType = dataType
_, err := genEmptyFieldData(field)
assert.Error(t, err)
}
// dim not found
for _, dataType := range vectorTypes {
field.DataType = dataType
_, err := genEmptyFieldData(field)
assert.Error(t, err)
}
field.TypeParams = []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}
for _, dataType := range vectorTypes {
field.DataType = dataType
fieldData, err := genEmptyFieldData(field)
assert.NoError(t, err)
assert.Equal(t, dataType, fieldData.GetType())
assert.Equal(t, field.GetName(), fieldData.GetFieldName())
assert.True(t, fieldDataEmpty(fieldData))
assert.Equal(t, field.GetFieldID(), fieldData.GetFieldId())
}
}
func TestFillIfEmpty(t *testing.T) {
t.Run("not empty, do nothing", func(t *testing.T) {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2},
},
},
},
}
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, nil)
assert.NoError(t, err)
})
t.Run("invalid schema", func(t *testing.T) {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: nil,
},
},
},
}
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, nil)
assert.Error(t, err)
})
t.Run("field not found", func(t *testing.T) {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: nil,
},
},
},
}
schema := &schemapb.CollectionSchema{
Name: "collection",
Description: "description",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
DataType: schemapb.DataType_Int64,
},
},
}
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{101}, schema)
assert.Error(t, err)
})
t.Run("unsupported data type", func(t *testing.T) {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: nil,
},
},
},
}
schema := &schemapb.CollectionSchema{
Name: "collection",
Description: "description",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
DataType: schemapb.DataType_String,
},
},
}
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100}, schema)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: nil,
},
},
},
}
schema := &schemapb.CollectionSchema{
Name: "collection",
Description: "description",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "field100",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 101,
Name: "field101",
DataType: schemapb.DataType_VarChar,
},
},
}
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, schema)
assert.NoError(t, err)
assert.Nil(t, result.GetOffset())
assert.Equal(t, 2, len(result.GetFieldsData()))
for _, fieldData := range result.GetFieldsData() {
assert.True(t, fieldDataEmpty(fieldData))
}
})
t.Run("normal case", func(t *testing.T) {
result := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: nil,
},
},
},
}
schema := &schemapb.CollectionSchema{
Name: "collection",
Description: "description",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "field100",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 101,
Name: "field101",
DataType: schemapb.DataType_VarChar,
},
},
}
err := FillRetrieveResultIfEmpty(NewInternalResult(result), []int64{100, 101}, schema)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
for _, fieldData := range result.GetFieldsData() {
assert.True(t, fieldDataEmpty(fieldData))
}
})
t.Run("normal case", func(t *testing.T) {
result := &milvuspb.QueryResults{
FieldsData: nil,
}
schema := &schemapb.CollectionSchema{
Name: "collection",
Description: "description",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "field100",
DataType: schemapb.DataType_Int64,
},
{
FieldID: 101,
Name: "field101",
DataType: schemapb.DataType_VarChar,
},
},
}
err := FillRetrieveResultIfEmpty(NewMilvusResult(result), []int64{100, 101}, schema)
assert.NoError(t, err)
assert.Equal(t, 2, len(result.GetFieldsData()))
for _, fieldData := range result.GetFieldsData() {
assert.True(t, fieldDataEmpty(fieldData))
}
})
}

View File

@ -0,0 +1,76 @@
package typeutil
import (
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
)
type RetrieveResults interface {
PreHandle()
ResultEmpty() bool
AppendFieldData(fieldData *schemapb.FieldData)
}
type segcoreResults struct {
result *segcorepb.RetrieveResults
}
func (r *segcoreResults) PreHandle() {
r.result.Offset = nil
r.result.FieldsData = nil
}
func (r *segcoreResults) AppendFieldData(fieldData *schemapb.FieldData) {
r.result.FieldsData = append(r.result.FieldsData, fieldData)
}
func (r *segcoreResults) ResultEmpty() bool {
return GetSizeOfIDs(r.result.GetIds()) <= 0
}
func NewSegcoreResults(result *segcorepb.RetrieveResults) RetrieveResults {
return &segcoreResults{result: result}
}
type internalResults struct {
result *internalpb.RetrieveResults
}
func (r *internalResults) PreHandle() {
r.result.FieldsData = nil
}
func (r *internalResults) AppendFieldData(fieldData *schemapb.FieldData) {
r.result.FieldsData = append(r.result.FieldsData, fieldData)
}
func (r *internalResults) ResultEmpty() bool {
return GetSizeOfIDs(r.result.GetIds()) <= 0
}
func NewInternalResult(result *internalpb.RetrieveResults) RetrieveResults {
return &internalResults{result: result}
}
type milvusResults struct {
result *milvuspb.QueryResults
}
func (r *milvusResults) PreHandle() {
r.result.FieldsData = nil
}
func (r *milvusResults) AppendFieldData(fieldData *schemapb.FieldData) {
r.result.FieldsData = append(r.result.FieldsData, fieldData)
}
func (r *milvusResults) ResultEmpty() bool {
// not very clear.
return len(r.result.GetFieldsData()) <= 0
}
func NewMilvusResult(result *milvuspb.QueryResults) RetrieveResults {
return &milvusResults{result: result}
}