Implement SearchByPks path for Search (#25882)

Signed-off-by: unfode <forrest.futao.wei@gmail.com>
pull/27724/head
Futao Wei 2023-10-16 03:34:08 -04:00 committed by GitHub
parent e3f2122618
commit 599012a340
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 612 additions and 87 deletions

View File

@ -103,7 +103,8 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collectionName", request.CollectionName),
zap.Int64("collectionID", request.CollectionID))
zap.Int64("collectionID", request.CollectionID),
)
log.Info("received request to invalidate collection meta cache")
@ -144,7 +145,11 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
method := "CreateDatabase"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
cct := &createDatabaseTask{
ctx: ctx,
@ -153,9 +158,11 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
rootCoord: node.rootCoord,
}
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
log := log.With(
zap.String("traceID", sp.SpanContext().TraceID().String()),
zap.String("role", typeutil.ProxyRole),
zap.String("dbName", request.DbName))
zap.String("dbName", request.DbName),
)
log.Info(rpcReceived(method))
if err := node.sched.ddQueue.Enqueue(cct); err != nil {
@ -174,8 +181,17 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
}
log.Info(rpcDone(method))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
return cct.result, nil
}
@ -189,7 +205,11 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
method := "DropDatabase"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
dct := &dropDatabaseTask{
ctx: ctx,
@ -198,9 +218,11 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
rootCoord: node.rootCoord,
}
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
log := log.With(
zap.String("traceID", sp.SpanContext().TraceID().String()),
zap.String("role", typeutil.ProxyRole),
zap.String("dbName", request.DbName))
zap.String("dbName", request.DbName),
)
log.Info(rpcReceived(method))
if err := node.sched.ddQueue.Enqueue(dct); err != nil {
@ -217,8 +239,17 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
}
log.Info(rpcDone(method))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
return dct.result, nil
}
@ -234,7 +265,11 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData
method := "ListDatabases"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
dct := &listDatabaseTask{
ctx: ctx,
@ -243,8 +278,10 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData
rootCoord: node.rootCoord,
}
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
zap.String("role", typeutil.ProxyRole))
log := log.With(
zap.String("traceID", sp.SpanContext().TraceID().String()),
zap.String("role", typeutil.ProxyRole),
)
log.Info(rpcReceived(method))
@ -264,8 +301,17 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData
}
log.Info(rpcDone(method), zap.Int("num of db", len(dct.result.DbNames)))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
return dct.result, nil
}
@ -281,7 +327,11 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
method := "CreateCollection"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
cct := &createCollectionTask{
ctx: ctx,
@ -299,7 +349,8 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
zap.String("collection", request.CollectionName),
zap.Int("len(schema)", lenOfSchema),
zap.Int32("shards_num", request.ShardsNum),
zap.String("consistency_level", request.ConsistencyLevel.String()))
zap.String("consistency_level", request.ConsistencyLevel.String()),
)
log.Debug(rpcReceived(method))
@ -316,7 +367,8 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
rpcEnqueued(method),
zap.Uint64("BeginTs", cct.BeginTs()),
zap.Uint64("EndTs", cct.EndTs()),
zap.Uint64("timestamp", request.Base.Timestamp))
zap.Uint64("timestamp", request.Base.Timestamp),
)
if err := cct.WaitToFinish(); err != nil {
log.Warn(
@ -332,10 +384,19 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
log.Debug(
rpcDone(method),
zap.Uint64("BeginTs", cct.BeginTs()),
zap.Uint64("EndTs", cct.EndTs()))
zap.Uint64("EndTs", cct.EndTs()),
)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return cct.result, nil
}
@ -349,7 +410,11 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
defer sp.End()
method := "DropCollection"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
dct := &dropCollectionTask{
ctx: ctx,
@ -363,7 +428,8 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
zap.String("collection", request.CollectionName),
)
log.Debug("DropCollection received")
@ -375,9 +441,11 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
return merr.Status(err), nil
}
log.Debug("DropCollection enqueued",
log.Debug(
"DropCollection enqueued",
zap.Uint64("BeginTs", dct.BeginTs()),
zap.Uint64("EndTs", dct.EndTs()))
zap.Uint64("EndTs", dct.EndTs()),
)
if err := dct.WaitToFinish(); err != nil {
log.Warn("DropCollection failed to WaitToFinish",
@ -389,12 +457,22 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
return merr.Status(err), nil
}
log.Debug("DropCollection done",
log.Debug(
"DropCollection done",
zap.Uint64("BeginTs", dct.BeginTs()),
zap.Uint64("EndTs", dct.EndTs()))
zap.Uint64("EndTs", dct.EndTs()),
)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return dct.result, nil
}
@ -410,13 +488,17 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
defer sp.End()
method := "HasCollection"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
zap.String("collection", request.CollectionName),
)
log.Debug("HasCollection received")
@ -438,9 +520,11 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
}, nil
}
log.Debug("HasCollection enqueued",
log.Debug(
"HasCollection enqueued",
zap.Uint64("BeginTS", hct.BeginTs()),
zap.Uint64("EndTS", hct.EndTs()))
zap.Uint64("EndTS", hct.EndTs()),
)
if err := hct.WaitToFinish(); err != nil {
log.Warn("HasCollection failed to WaitToFinish",
@ -455,13 +539,22 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
}, nil
}
log.Debug("HasCollection done",
log.Debug(
"HasCollection done",
zap.Uint64("BeginTS", hct.BeginTs()),
zap.Uint64("EndTS", hct.EndTs()))
zap.Uint64("EndTS", hct.EndTs()),
)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return hct.result, nil
}
@ -475,8 +568,12 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
defer sp.End()
method := "LoadCollection"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
lct := &loadCollectionTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
@ -489,7 +586,8 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Bool("refreshMode", request.Refresh))
zap.Bool("refreshMode", request.Refresh),
)
log.Debug("LoadCollection received")
@ -502,9 +600,11 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
return merr.Status(err), nil
}
log.Debug("LoadCollection enqueued",
log.Debug(
"LoadCollection enqueued",
zap.Uint64("BeginTS", lct.BeginTs()),
zap.Uint64("EndTS", lct.EndTs()))
zap.Uint64("EndTS", lct.EndTs()),
)
if err := lct.WaitToFinish(); err != nil {
log.Warn("LoadCollection failed to WaitToFinish",
@ -516,13 +616,22 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
return merr.Status(err), nil
}
log.Debug("LoadCollection done",
log.Debug(
"LoadCollection done",
zap.Uint64("BeginTS", lct.BeginTs()),
zap.Uint64("EndTS", lct.EndTs()))
zap.Uint64("EndTS", lct.EndTs()),
)
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxyReqLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return lct.result, nil
}
@ -2379,11 +2488,15 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel, request.GetCollectionName()).Add(float64(receiveSize))
metrics.SearchLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))
metrics.ProxyReceivedNQ.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel, request.GetCollectionName()).Add(float64(request.GetNq()))
metrics.SearchLabel,
request.GetCollectionName(),
).Add(float64(request.GetNq()))
rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()))
@ -2392,14 +2505,29 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
Status: merr.Status(err),
}, nil
}
method := "Search"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search")
defer sp.End()
if request.SearchByPrimaryKeys {
placeholderGroupBytes, err := node.getVectorPlaceholderGroupForSearchByPks(ctx, request)
if err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}
request.PlaceholderGroup = placeholderGroupBytes
}
qt := &searchTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
@ -2428,7 +2556,8 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)),
zap.Any("OutputFields", request.OutputFields),
zap.Any("search_params", request.SearchParams),
zap.Uint64("guarantee_timestamp", guaranteeTs))
zap.Uint64("guarantee_timestamp", guaranteeTs),
)
defer func() {
span := tr.ElapseSpan()
@ -2442,10 +2571,14 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
log.Warn(
rpcFailedToEnqueue(method),
zap.Error(err))
zap.Error(err),
)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.AbandonLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.AbandonLabel,
).Inc()
return &milvuspb.SearchResults{
Status: merr.Status(err),
@ -2455,15 +2588,20 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
log.Debug(
rpcEnqueued(method),
zap.Uint64("timestamp", qt.Base.Timestamp))
zap.Uint64("timestamp", qt.Base.Timestamp),
)
if err := qt.WaitToFinish(); err != nil {
log.Warn(
rpcFailedToWaitToFinish(method),
zap.Error(err))
zap.Error(err),
)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.FailLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.FailLabel,
).Inc()
return &milvuspb.SearchResults{
Status: merr.Status(err),
@ -2471,19 +2609,34 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
}
span := tr.CtxRecord(ctx, "wait search result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(span.Milliseconds()))
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
).Observe(float64(span.Milliseconds()))
tr.CtxRecord(ctx, "wait search result")
log.Debug(rpcDone(method))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(qt.result.GetResults().GetNumQueries()))
searchDur := tr.ElapseSpan().Milliseconds()
metrics.ProxySQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(searchDur))
metrics.ProxyCollectionSQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel, request.CollectionName).Observe(float64(searchDur))
metrics.ProxySQLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
).Observe(float64(searchDur))
metrics.ProxyCollectionSQLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.CollectionName,
).Observe(float64(searchDur))
if qt.result != nil {
sentSize := proto.Size(qt.result)
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
@ -2492,6 +2645,61 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
return qt.result, nil
}
// getVectorPlaceholderGroupForSearchByPks builds a vector placeholder group for a
// search-by-primary-keys request. The incoming placeholder group must carry exactly
// one value: a boolean expression selecting the primary keys (e.g. "pk in [1, 2]").
// That expression is executed as a Query fetching the vector field named by the
// request's anns_field search param, and the returned vectors are re-packed into a
// placeholder group consumable by the regular search path.
func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) {
	placeholderGroup := &commonpb.PlaceholderGroup{}
	if err := proto.Unmarshal(request.PlaceholderGroup, placeholderGroup); err != nil {
		return nil, err
	}

	// Exactly one placeholder with exactly one value (the pk expression) is allowed.
	if len(placeholderGroup.Placeholders) != 1 || len(placeholderGroup.Placeholders[0].Values) != 1 {
		return nil, merr.WrapErrParameterInvalidMsg("please provide primary key")
	}
	queryExpr := string(placeholderGroup.Placeholders[0].Values[0])

	// The field to fetch is the same vector field the search will run against.
	annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, request.SearchParams)
	if err != nil {
		return nil, err
	}

	queryRequest := &milvuspb.QueryRequest{
		Base:                  request.Base,
		DbName:                request.DbName,
		CollectionName:        request.CollectionName,
		Expr:                  queryExpr,
		OutputFields:          []string{annsField},
		PartitionNames:        request.PartitionNames,
		TravelTimestamp:       request.TravelTimestamp,
		GuaranteeTimestamp:    request.GuaranteeTimestamp,
		QueryParams:           nil,
		NotReturnAllMeta:      request.NotReturnAllMeta,
		ConsistencyLevel:      request.ConsistencyLevel,
		UseDefaultConsistency: request.UseDefaultConsistency,
	}

	queryResults, err := node.Query(ctx, queryRequest)
	// Do not discard the transport-level error; previously it was ignored with `_`.
	if err != nil {
		return nil, err
	}
	if err := merr.Error(queryResults.GetStatus()); err != nil {
		return nil, err
	}

	// Locate the vector field inside the query result.
	var vectorFieldsData *schemapb.FieldData
	for _, fieldsData := range queryResults.GetFieldsData() {
		if fieldsData.GetFieldName() == annsField {
			vectorFieldsData = fieldsData
			break
		}
	}
	// Guard against the anns field missing from the result; passing nil onward
	// would dereference a nil FieldData during conversion.
	if vectorFieldsData == nil {
		return nil, merr.WrapErrParameterInvalidMsg("anns field not found in query result")
	}

	placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(vectorFieldsData)
	if err != nil {
		return nil, err
	}

	return placeholderGroupBytes, nil
}
// Flush notify data nodes to persist the data of collection.
func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) {
resp := &milvuspb.FlushResponse{
@ -2567,11 +2775,15 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel, request.GetCollectionName()).Add(float64(receiveSize))
metrics.QueryLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))
metrics.ProxyReceivedNQ.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel, request.GetCollectionName()).Add(float64(1))
metrics.SearchLabel,
request.GetCollectionName(),
).Add(float64(1))
rateCol.Add(internalpb.RateType_DQLQuery.String(), 1)
@ -2602,14 +2814,18 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
method := "Query"
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.TotalLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.TotalLabel,
).Inc()
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Strings("partitions", request.PartitionNames))
zap.Strings("partitions", request.PartitionNames),
)
defer func() {
span := tr.ElapseSpan()
@ -2629,15 +2845,20 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
zap.String("expr", request.Expr),
zap.Strings("OutputFields", request.OutputFields),
zap.Uint64("travel_timestamp", request.TravelTimestamp),
zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp))
zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp),
)
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
log.Warn(
rpcFailedToEnqueue(method),
zap.Error(err))
zap.Error(err),
)
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.AbandonLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.AbandonLabel,
).Inc()
return &milvuspb.QueryResults{
Status: merr.Status(err),
@ -2660,21 +2881,34 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
}, nil
}
span := tr.CtxRecord(ctx, "wait query result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel).Observe(float64(span.Milliseconds()))
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel,
).Observe(float64(span.Milliseconds()))
log.Debug(rpcDone(method))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyFunctionCall.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
method,
metrics.SuccessLabel,
).Inc()
metrics.ProxySQLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyCollectionSQLatency.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel,
request.CollectionName,
).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxySQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyCollectionSQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.QueryLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
sentSize := proto.Size(qt.result)
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
return qt.result, nil
}

View File

@ -17,7 +17,10 @@
package proxy
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand"
"net"
@ -1005,6 +1008,7 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
var insertedIds []int64
wg.Add(1)
t.Run("insert", func(t *testing.T) {
defer wg.Done()
@ -1016,6 +1020,13 @@ func TestProxy(t *testing.T) {
assert.Equal(t, rowNum, len(resp.SuccIndex))
assert.Equal(t, 0, len(resp.ErrIndex))
assert.Equal(t, int64(rowNum), resp.InsertCnt)
switch field := resp.GetIDs().GetIdField().(type) {
case *schemapb.IDs_IntId:
insertedIds = field.IntId.GetData()
default:
t.Fatalf("Unexpected ID type")
}
})
// TODO(dragondriver): proxy.Delete()
@ -1328,6 +1339,135 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
nprobe := 10
topk := 10
roundDecimal := 6
expr := fmt.Sprintf("%s > 0", int64Field)
constructVectorsPlaceholderGroup := func() *commonpb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
assert.NoError(t, err)
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
constructSearchRequest := func() *milvuspb.SearchRequest {
plg := constructVectorsPlaceholderGroup()
plgBs, err := proto.Marshal(plg)
assert.NoError(t, err)
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
assert.NoError(t, err)
searchParams := []*commonpb.KeyValuePair{
{Key: MetricTypeKey, Value: metric.L2},
{Key: SearchParamsKey, Value: string(b)},
{Key: AnnsFieldKey, Value: floatVecField},
{Key: TopKKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: searchParams,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
SearchByPrimaryKeys: false,
}
}
wg.Add(1)
t.Run("search", func(t *testing.T) {
defer wg.Done()
req := constructSearchRequest()
resp, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
exprBytes := []byte(expr)
return &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_None,
Values: [][]byte{exprBytes},
},
},
}
}
constructSearchByPksRequest := func() *milvuspb.SearchRequest {
plg := constructPrimaryKeysPlaceholderGroup()
plgBs, err := proto.Marshal(plg)
assert.NoError(t, err)
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
assert.NoError(t, err)
searchParams := []*commonpb.KeyValuePair{
{Key: MetricTypeKey, Value: metric.L2},
{Key: SearchParamsKey, Value: string(b)},
{Key: AnnsFieldKey, Value: floatVecField},
{Key: TopKKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: "",
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: searchParams,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
SearchByPrimaryKeys: true,
}
}
wg.Add(1)
t.Run("search by primary keys", func(t *testing.T) {
defer wg.Done()
req := constructSearchByPksRequest()
resp, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
// nprobe := 10
// topk := 10
// roundDecimal := 6

View File

@ -0,0 +1,118 @@
package funcutil
import (
"encoding/binary"
"math"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// FieldDataToPlaceholderGroupBytes converts vector field data into the serialized
// bytes of a PlaceholderGroup holding a single placeholder tagged "$0".
// Only vector field types (FloatVector, BinaryVector, Float16Vector) are supported;
// any other type yields an error from the conversion step.
func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
	placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
	if err != nil {
		return nil, err
	}

	placeholderGroup := &commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
	}

	// Propagate marshalling failures instead of silently returning empty bytes
	// (the error was previously discarded with `_`).
	return proto.Marshal(placeholderGroup)
}
// fieldDataToPlaceholderValue converts vector field data into a single
// PlaceholderValue tagged "$0". Non-vector field types produce an error.
func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.PlaceholderValue, error) {
	var (
		phType commonpb.PlaceholderType
		values [][]byte
	)

	switch fieldData.Type {
	case schemapb.DataType_FloatVector:
		vectors := fieldData.GetVectors()
		fv, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
		if !ok {
			return nil, errors.New("vector data is not schemapb.VectorField_FloatVector")
		}
		phType = commonpb.PlaceholderType_FloatVector
		values = flattenedFloatVectorsToByteVectors(fv.FloatVector.Data, int(vectors.Dim))

	case schemapb.DataType_BinaryVector:
		vectors := fieldData.GetVectors()
		bv, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
		if !ok {
			return nil, errors.New("vector data is not schemapb.VectorField_BinaryVector")
		}
		phType = commonpb.PlaceholderType_BinaryVector
		values = flattenedByteVectorsToByteVectors(bv.BinaryVector, int(vectors.Dim))

	case schemapb.DataType_Float16Vector:
		vectors := fieldData.GetVectors()
		hv, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
		if !ok {
			return nil, errors.New("vector data is not schemapb.VectorField_Float16Vector")
		}
		phType = commonpb.PlaceholderType_Float16Vector
		values = flattenedFloat16VectorsToByteVectors(hv.Float16Vector, int(vectors.Dim))

	default:
		return nil, errors.New("field is not a vector field")
	}

	return &commonpb.PlaceholderValue{
		Tag:    "$0",
		Type:   phType,
		Values: values,
	}, nil
}
// flattenedFloatVectorsToByteVectors splits a flat float32 buffer into vectors of
// `dimension` elements and encodes each one as little-endian float32 bytes
// (4 bytes per element). The original implementation first materialized an
// intermediate [][]float32 and grew the result with repeated appends; this version
// presizes the result and encodes each vector in a single pass.
func flattenedFloatVectorsToByteVectors(flattenedVectors []float32, dimension int) [][]byte {
	result := make([][]byte, 0, len(flattenedVectors)/dimension)
	for i := 0; i < len(flattenedVectors); i += dimension {
		vec := make([]byte, 4*dimension) // float32 occupies 4 bytes
		for j, f := range flattenedVectors[i : i+dimension] {
			binary.LittleEndian.PutUint32(vec[4*j:], math.Float32bits(f))
		}
		result = append(result, vec)
	}
	return result
}
// flattenedFloatVectorsToFloatVectors splits a flat float32 buffer into
// `dimension`-sized sub-slices. The returned vectors alias the input's backing
// array (no copy); callers must not mutate them if the source must stay intact.
// The result slice is presized to avoid repeated growth during append.
func flattenedFloatVectorsToFloatVectors(flattenedVectors []float32, dimension int) [][]float32 {
	result := make([][]float32, 0, len(flattenedVectors)/dimension)
	for i := 0; i < len(flattenedVectors); i += dimension {
		result = append(result, flattenedVectors[i:i+dimension])
	}
	return result
}
// floatVectorToByteVector encodes a float32 vector as little-endian bytes,
// 4 bytes per element. Writes directly into a presized buffer instead of
// appending a per-element scratch slice as the original did.
func floatVectorToByteVector(vector []float32) []byte {
	data := make([]byte, 4*len(vector)) // float32 occupies 4 bytes
	for i, f := range vector {
		binary.LittleEndian.PutUint32(data[4*i:], math.Float32bits(f))
	}
	return data
}
// flattenedByteVectorsToByteVectors splits a flat byte buffer into
// `dimension`-sized sub-slices. The returned vectors alias the input's backing
// array (no copy). The result slice is presized to avoid repeated growth.
func flattenedByteVectorsToByteVectors(flattenedVectors []byte, dimension int) [][]byte {
	result := make([][]byte, 0, len(flattenedVectors)/dimension)
	for i := 0; i < len(flattenedVectors); i += dimension {
		result = append(result, flattenedVectors[i:i+dimension])
	}
	return result
}
// flattenedFloat16VectorsToByteVectors splits a flat float16 byte buffer into
// per-vector sub-slices of 2*dimension bytes (float16 occupies 2 bytes).
// The returned vectors alias the input's backing array (no copy); the result
// slice is presized to avoid repeated growth.
func flattenedFloat16VectorsToByteVectors(flattenedVectors []byte, dimension int) [][]byte {
	vectorBytes := 2 * dimension
	result := make([][]byte, 0, len(flattenedVectors)/vectorBytes)
	for i := 0; i < len(flattenedVectors); i += vectorBytes {
		result = append(result, flattenedVectors[i:i+vectorBytes])
	}
	return result
}

View File

@ -0,0 +1,33 @@
package funcutil
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Test_flattenedByteVectorsToByteVectors verifies that a flat byte buffer is
// split into fixed-size byte vectors of the requested dimension.
func Test_flattenedByteVectorsToByteVectors(t *testing.T) {
	flat := []byte{0, 1, 2, 3, 4, 5}

	got := flattenedByteVectorsToByteVectors(flat, 3)

	want := [][]byte{
		{0, 1, 2},
		{3, 4, 5},
	}
	assert.Equal(t, want, got)
}
// Test_flattenedFloat16VectorsToByteVectors verifies that a flat float16 byte
// buffer is split into per-vector chunks of 2*dimension bytes.
func Test_flattenedFloat16VectorsToByteVectors(t *testing.T) {
	flat := []byte{0, 1, 2, 3, 4, 5, 6, 7}

	got := flattenedFloat16VectorsToByteVectors(flat, 2)

	want := [][]byte{
		{0, 1, 2, 3},
		{4, 5, 6, 7},
	}
	assert.Equal(t, want, got)
}