mirror of https://github.com/milvus-io/milvus.git
Signed-off-by: longjiquan <jiquan.long@zilliz.com> Signed-off-by: longjiquan <jiquan.long@zilliz.com>pull/21111/head
parent
714782fce1
commit
41f7b10abb
|
@ -87,3 +87,7 @@ const (
|
|||
const (
|
||||
CollectionTTLConfigKey = "collection.ttl.seconds"
|
||||
)
|
||||
|
||||
func IsSystemField(fieldID int64) bool {
|
||||
return fieldID < StartOfUserFieldID
|
||||
}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsSystemField(t *testing.T) {
|
||||
type args struct {
|
||||
fieldID int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
args: args{fieldID: StartOfUserFieldID},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
args: args{fieldID: StartOfUserFieldID + 1},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
args: args{fieldID: TimeStampField},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{fieldID: RowIDField},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, IsSystemField(tt.args.fieldID), "IsSystemField(%v)", tt.args.fieldID)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -50,6 +50,7 @@ type queryTask struct {
|
|||
ids *schemapb.IDs
|
||||
collectionName string
|
||||
queryParams *queryParams
|
||||
schema *schemapb.CollectionSchema
|
||||
|
||||
resultBuf chan *internalpb.RetrieveResults
|
||||
toReduceResults []*internalpb.RetrieveResults
|
||||
|
@ -110,6 +111,16 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio
|
|||
return outputFieldIDs, nil
|
||||
}
|
||||
|
||||
func filterSystemFields(outputFieldIDs []UniqueID) []UniqueID {
|
||||
filtered := make([]UniqueID, 0, len(outputFieldIDs))
|
||||
for _, outputFieldID := range outputFieldIDs {
|
||||
if !common.IsSystemField(outputFieldID) {
|
||||
filtered = append(filtered, outputFieldID)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// parseQueryParams get limit and offset from queryParamsPair, both are optional.
|
||||
func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, error) {
|
||||
var (
|
||||
|
@ -225,6 +236,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
||||
t.schema = schema
|
||||
|
||||
if t.ids != nil {
|
||||
pkField := ""
|
||||
|
@ -348,7 +360,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
|||
|
||||
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
|
||||
tr.CtxRecord(ctx, "reduceResultStart")
|
||||
t.result, err = reduceRetrieveResults(ctx, t.toReduceResults, t.queryParams)
|
||||
t.result, err = reduceRetrieveResultsAndFillIfEmpty(ctx, t.toReduceResults, t.queryParams, t.GetOutputFieldsId(), t.schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -493,6 +505,21 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func reduceRetrieveResultsAndFillIfEmpty(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams, outputFieldsID []int64, schema *schemapb.CollectionSchema) (*milvuspb.QueryResults, error) {
|
||||
result, err := reduceRetrieveResults(ctx, retrieveResults, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// filter system fields.
|
||||
filtered := filterSystemFields(outputFieldsID)
|
||||
if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewMilvusResult(result), filtered, schema); err != nil {
|
||||
return nil, fmt.Errorf("failed to fill retrieve results: %s", err.Error())
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *queryTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
|
|
@ -651,3 +651,9 @@ func getFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType,
|
|||
|
||||
return fieldData
|
||||
}
|
||||
|
||||
func Test_filterSystemFields(t *testing.T) {
|
||||
outputFieldIDs := []UniqueID{common.RowIDField, common.TimeStampField, common.StartOfUserFieldID}
|
||||
filtered := filterSystemFields(outputFieldIDs)
|
||||
assert.ElementsMatch(t, []UniqueID{common.StartOfUserFieldID}, filtered)
|
||||
}
|
||||
|
|
|
@ -1051,7 +1051,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
|
|||
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
|
||||
|
||||
results = append(results, streamingResult)
|
||||
ret, err2 := mergeInternalRetrieveResult(ctx, results, req.Req.GetLimit())
|
||||
ret, err2 := mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), qs.collection.Schema())
|
||||
if err2 != nil {
|
||||
failRet.Status.Reason = err2.Error()
|
||||
return failRet, nil
|
||||
|
@ -1093,6 +1093,13 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
|
|||
},
|
||||
}
|
||||
|
||||
coll, err := node.metaReplica.getCollectionByID(req.GetReq().GetCollectionID())
|
||||
if err != nil {
|
||||
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
failRet.Status.Reason = err.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
toMergeResults := make([]*internalpb.RetrieveResults, 0)
|
||||
runningGp, runningCtx := errgroup.WithContext(ctx)
|
||||
mu := &sync.Mutex{}
|
||||
|
@ -1127,7 +1134,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
|
|||
if err := runningGp.Wait(); err != nil {
|
||||
return failRet, nil
|
||||
}
|
||||
ret, err := mergeInternalRetrieveResult(ctx, toMergeResults, req.GetReq().GetLimit())
|
||||
ret, err := mergeInternalRetrieveResultsAndFillIfEmpty(ctx, toMergeResults, req.GetReq().GetLimit(), req.GetReq().GetOutputFieldsId(), coll.Schema())
|
||||
if err != nil {
|
||||
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
failRet.Status.Reason = err.Error()
|
||||
|
|
|
@ -372,6 +372,46 @@ func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func mergeSegcoreRetrieveResultsAndFillIfEmpty(
|
||||
ctx context.Context,
|
||||
retrieveResults []*segcorepb.RetrieveResults,
|
||||
limit int64,
|
||||
outputFieldsID []int64,
|
||||
schema *schemapb.CollectionSchema,
|
||||
) (*segcorepb.RetrieveResults, error) {
|
||||
|
||||
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewSegcoreResults(mergedResult), outputFieldsID, schema); err != nil {
|
||||
return nil, fmt.Errorf("failed to fill segcore retrieve results: %s", err.Error())
|
||||
}
|
||||
|
||||
return mergedResult, nil
|
||||
}
|
||||
|
||||
func mergeInternalRetrieveResultsAndFillIfEmpty(
|
||||
ctx context.Context,
|
||||
retrieveResults []*internalpb.RetrieveResults,
|
||||
limit int64,
|
||||
outputFieldsID []int64,
|
||||
schema *schemapb.CollectionSchema,
|
||||
) (*internalpb.RetrieveResults, error) {
|
||||
|
||||
mergedResult, err := mergeInternalRetrieveResult(ctx, retrieveResults, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewInternalResult(mergedResult), outputFieldsID, schema); err != nil {
|
||||
return nil, fmt.Errorf("failed to fill internal retrieve results: %s", err.Error())
|
||||
}
|
||||
|
||||
return mergedResult, nil
|
||||
}
|
||||
|
||||
// func printSearchResultData(data *schemapb.SearchResultData, header string) {
|
||||
// size := len(data.Ids.GetIntId().Data)
|
||||
// if size != len(data.Scores) {
|
||||
|
|
|
@ -61,7 +61,7 @@ func (q *queryTask) queryOnStreaming() error {
|
|||
}
|
||||
|
||||
// check if collection has been released, check streaming since it's released first
|
||||
_, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
|
||||
coll, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func (q *queryTask) queryOnStreaming() error {
|
|||
}
|
||||
|
||||
q.tr.RecordSpan()
|
||||
mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults, q.iReq.GetLimit())
|
||||
mergedResult, err := mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, sResults, q.iReq.GetLimit(), q.iReq.GetOutputFieldsId(), coll.Schema())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ func (q *queryTask) queryOnHistorical() error {
|
|||
}
|
||||
|
||||
// check if collection has been released, check historical since it's released first
|
||||
_, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
|
||||
coll, err := q.QS.metaReplica.getCollectionByID(q.CollectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -129,10 +129,11 @@ func (q *queryTask) queryOnHistorical() error {
|
|||
return err
|
||||
}
|
||||
|
||||
mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, q.req.GetReq().GetLimit())
|
||||
mergedResult, err := mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, retrieveResults, q.req.GetReq().GetLimit(), q.iReq.GetOutputFieldsId(), coll.Schema())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.Ret = &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Ids: mergedResult.Ids,
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
func fieldDataEmpty(data *schemapb.FieldData) bool {
|
||||
if data == nil {
|
||||
return true
|
||||
}
|
||||
switch realData := data.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
switch realScalars := realData.Scalars.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
return len(realScalars.BoolData.GetData()) <= 0
|
||||
case *schemapb.ScalarField_LongData:
|
||||
return len(realScalars.LongData.GetData()) <= 0
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
return len(realScalars.FloatData.GetData()) <= 0
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
return len(realScalars.DoubleData.GetData()) <= 0
|
||||
case *schemapb.ScalarField_StringData:
|
||||
return len(realScalars.StringData.GetData()) <= 0
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
switch realVectors := realData.Vectors.Data.(type) {
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
return len(realVectors.BinaryVector) <= 0
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
return len(realVectors.FloatVector.Data) <= 0
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func genEmptyBoolFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: nil}},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}
|
||||
}
|
||||
|
||||
func genEmptyIntFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: nil}},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}
|
||||
}
|
||||
|
||||
func genEmptyFloatFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{FloatData: &schemapb.FloatArray{Data: nil}},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}
|
||||
}
|
||||
|
||||
func genEmptyDoubleFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{DoubleData: &schemapb.DoubleArray{Data: nil}},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}
|
||||
}
|
||||
|
||||
func genEmptyVarCharFieldData(field *schemapb.FieldSchema) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: nil}},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}
|
||||
}
|
||||
|
||||
func genEmptyBinaryVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) {
|
||||
dim, err := GetDim(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func genEmptyFloatVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) {
|
||||
dim, err := GetDim(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &schemapb.FieldData{
|
||||
Type: field.GetDataType(),
|
||||
FieldName: field.GetName(),
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{Data: nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldId: field.GetFieldID(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func genEmptyFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) {
|
||||
dataType := field.GetDataType()
|
||||
switch dataType {
|
||||
case schemapb.DataType_Bool:
|
||||
return genEmptyBoolFieldData(field), nil
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64:
|
||||
return genEmptyIntFieldData(field), nil
|
||||
case schemapb.DataType_Float:
|
||||
return genEmptyFloatFieldData(field), nil
|
||||
case schemapb.DataType_Double:
|
||||
return genEmptyDoubleFieldData(field), nil
|
||||
case schemapb.DataType_VarChar:
|
||||
return genEmptyVarCharFieldData(field), nil
|
||||
case schemapb.DataType_BinaryVector:
|
||||
return genEmptyBinaryVectorFieldData(field)
|
||||
case schemapb.DataType_FloatVector:
|
||||
return genEmptyFloatVectorFieldData(field)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported data type: %s", dataType.String())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
// GetDim get dimension of field. Maybe also helpful outside.
|
||||
func GetDim(field *schemapb.FieldSchema) (int64, error) {
|
||||
if !IsVectorType(field.GetDataType()) {
|
||||
return 0, fmt.Errorf("%s is not of vector type", field.GetDataType())
|
||||
}
|
||||
h := NewKvPairs(append(field.GetIndexParams(), field.GetTypeParams()...))
|
||||
dimStr, err := h.Get("dim")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("dim not found")
|
||||
}
|
||||
dim, err := strconv.Atoi(dimStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid dimension: %s", dimStr)
|
||||
}
|
||||
return int64(dim), nil
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
)
|
||||
|
||||
type kvPairsHelper[K comparable, V any] struct {
|
||||
kvPairs map[K]V
|
||||
}
|
||||
|
||||
func (h *kvPairsHelper[K, V]) Get(k K) (V, error) {
|
||||
v, ok := h.kvPairs[k]
|
||||
if !ok {
|
||||
return v, fmt.Errorf("%v not found", k)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func NewKvPairs(pairs []*commonpb.KeyValuePair) *kvPairsHelper[string, string] {
|
||||
helper := &kvPairsHelper[string, string]{
|
||||
kvPairs: make(map[string]string),
|
||||
}
|
||||
|
||||
for _, pair := range pairs {
|
||||
helper.kvPairs[pair.GetKey()] = pair.GetValue()
|
||||
}
|
||||
|
||||
return helper
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewKvPairs(t *testing.T) {
|
||||
kvPairs := []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "128"},
|
||||
}
|
||||
h := NewKvPairs(kvPairs)
|
||||
v, err := h.Get("dim")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "128", v)
|
||||
_, err = h.Get("not_exist")
|
||||
assert.Error(t, err)
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
func preHandleEmptyResult(result RetrieveResults) {
|
||||
result.PreHandle()
|
||||
}
|
||||
|
||||
func appendFieldData(result RetrieveResults, fieldData *schemapb.FieldData) {
|
||||
result.AppendFieldData(fieldData)
|
||||
}
|
||||
|
||||
func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, schema *schemapb.CollectionSchema) error {
|
||||
if !result.ResultEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
preHandleEmptyResult(result)
|
||||
|
||||
helper, err := CreateSchemaHelper(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, outputFieldID := range outputFieldIds {
|
||||
field, err := helper.GetFieldFromID(outputFieldID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
emptyFieldData, err := genEmptyFieldData(field)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
appendFieldData(result, emptyFieldData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,253 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
func TestGenEmptyFieldData(t *testing.T) {
|
||||
allTypes := []schemapb.DataType{
|
||||
schemapb.DataType_Bool,
|
||||
schemapb.DataType_Int8,
|
||||
schemapb.DataType_Int16,
|
||||
schemapb.DataType_Int32,
|
||||
schemapb.DataType_Int64,
|
||||
schemapb.DataType_Float,
|
||||
schemapb.DataType_Double,
|
||||
schemapb.DataType_VarChar,
|
||||
}
|
||||
allUnsupportedTypes := []schemapb.DataType{
|
||||
schemapb.DataType_String,
|
||||
schemapb.DataType_None,
|
||||
}
|
||||
vectorTypes := []schemapb.DataType{
|
||||
schemapb.DataType_BinaryVector,
|
||||
schemapb.DataType_FloatVector,
|
||||
}
|
||||
|
||||
field := &schemapb.FieldSchema{Name: "field_name", FieldID: 100}
|
||||
for _, dataType := range allTypes {
|
||||
field.DataType = dataType
|
||||
fieldData, err := genEmptyFieldData(field)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataType, fieldData.GetType())
|
||||
assert.Equal(t, field.GetName(), fieldData.GetFieldName())
|
||||
assert.True(t, fieldDataEmpty(fieldData))
|
||||
assert.Equal(t, field.GetFieldID(), fieldData.GetFieldId())
|
||||
}
|
||||
|
||||
for _, dataType := range allUnsupportedTypes {
|
||||
field.DataType = dataType
|
||||
_, err := genEmptyFieldData(field)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// dim not found
|
||||
for _, dataType := range vectorTypes {
|
||||
field.DataType = dataType
|
||||
_, err := genEmptyFieldData(field)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
field.TypeParams = []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}
|
||||
for _, dataType := range vectorTypes {
|
||||
field.DataType = dataType
|
||||
fieldData, err := genEmptyFieldData(field)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataType, fieldData.GetType())
|
||||
assert.Equal(t, field.GetName(), fieldData.GetFieldName())
|
||||
assert.True(t, fieldDataEmpty(fieldData))
|
||||
assert.Equal(t, field.GetFieldID(), fieldData.GetFieldId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillIfEmpty(t *testing.T) {
|
||||
t.Run("not empty, do nothing", func(t *testing.T) {
|
||||
result := &segcorepb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, nil)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid schema", func(t *testing.T) {
|
||||
result := &segcorepb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("field not found", func(t *testing.T) {
|
||||
result := &segcorepb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "collection",
|
||||
Description: "description",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{101}, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported data type", func(t *testing.T) {
|
||||
result := &segcorepb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "collection",
|
||||
Description: "description",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
DataType: schemapb.DataType_String,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100}, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
result := &segcorepb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "collection",
|
||||
Description: "description",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "field100",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "field101",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, schema)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, result.GetOffset())
|
||||
assert.Equal(t, 2, len(result.GetFieldsData()))
|
||||
for _, fieldData := range result.GetFieldsData() {
|
||||
assert.True(t, fieldDataEmpty(fieldData))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
result := &internalpb.RetrieveResults{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "collection",
|
||||
Description: "description",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "field100",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "field101",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewInternalResult(result), []int64{100, 101}, schema)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(result.GetFieldsData()))
|
||||
for _, fieldData := range result.GetFieldsData() {
|
||||
assert.True(t, fieldDataEmpty(fieldData))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
result := &milvuspb.QueryResults{
|
||||
FieldsData: nil,
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "collection",
|
||||
Description: "description",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "field100",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "field101",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := FillRetrieveResultIfEmpty(NewMilvusResult(result), []int64{100, 101}, schema)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(result.GetFieldsData()))
|
||||
for _, fieldData := range result.GetFieldsData() {
|
||||
assert.True(t, fieldDataEmpty(fieldData))
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
)
|
||||
|
||||
type RetrieveResults interface {
|
||||
PreHandle()
|
||||
ResultEmpty() bool
|
||||
AppendFieldData(fieldData *schemapb.FieldData)
|
||||
}
|
||||
|
||||
type segcoreResults struct {
|
||||
result *segcorepb.RetrieveResults
|
||||
}
|
||||
|
||||
func (r *segcoreResults) PreHandle() {
|
||||
r.result.Offset = nil
|
||||
r.result.FieldsData = nil
|
||||
}
|
||||
|
||||
func (r *segcoreResults) AppendFieldData(fieldData *schemapb.FieldData) {
|
||||
r.result.FieldsData = append(r.result.FieldsData, fieldData)
|
||||
}
|
||||
|
||||
func (r *segcoreResults) ResultEmpty() bool {
|
||||
return GetSizeOfIDs(r.result.GetIds()) <= 0
|
||||
}
|
||||
|
||||
func NewSegcoreResults(result *segcorepb.RetrieveResults) RetrieveResults {
|
||||
return &segcoreResults{result: result}
|
||||
}
|
||||
|
||||
type internalResults struct {
|
||||
result *internalpb.RetrieveResults
|
||||
}
|
||||
|
||||
func (r *internalResults) PreHandle() {
|
||||
r.result.FieldsData = nil
|
||||
}
|
||||
|
||||
func (r *internalResults) AppendFieldData(fieldData *schemapb.FieldData) {
|
||||
r.result.FieldsData = append(r.result.FieldsData, fieldData)
|
||||
}
|
||||
|
||||
func (r *internalResults) ResultEmpty() bool {
|
||||
return GetSizeOfIDs(r.result.GetIds()) <= 0
|
||||
}
|
||||
|
||||
func NewInternalResult(result *internalpb.RetrieveResults) RetrieveResults {
|
||||
return &internalResults{result: result}
|
||||
}
|
||||
|
||||
type milvusResults struct {
|
||||
result *milvuspb.QueryResults
|
||||
}
|
||||
|
||||
func (r *milvusResults) PreHandle() {
|
||||
r.result.FieldsData = nil
|
||||
}
|
||||
|
||||
func (r *milvusResults) AppendFieldData(fieldData *schemapb.FieldData) {
|
||||
r.result.FieldsData = append(r.result.FieldsData, fieldData)
|
||||
}
|
||||
|
||||
func (r *milvusResults) ResultEmpty() bool {
|
||||
// not very clear.
|
||||
return len(r.result.GetFieldsData()) <= 0
|
||||
}
|
||||
|
||||
func NewMilvusResult(result *milvuspb.QueryResults) RetrieveResults {
|
||||
return &milvusResults{result: result}
|
||||
}
|
Loading…
Reference in New Issue