mirror of https://github.com/milvus-io/milvus.git
Implement SearchByPks path for Search (#25882)
Signed-off-by: unfode <forrest.futao.wei@gmail.com>pull/27724/head
parent
e3f2122618
commit
599012a340
|
@ -103,7 +103,8 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
|||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collectionName", request.CollectionName),
|
||||
zap.Int64("collectionID", request.CollectionID))
|
||||
zap.Int64("collectionID", request.CollectionID),
|
||||
)
|
||||
|
||||
log.Info("received request to invalidate collection meta cache")
|
||||
|
||||
|
@ -144,7 +145,11 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
|
|||
method := "CreateDatabase"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
cct := &createDatabaseTask{
|
||||
ctx: ctx,
|
||||
|
@ -153,9 +158,11 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
|
|||
rootCoord: node.rootCoord,
|
||||
}
|
||||
|
||||
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
|
||||
log := log.With(
|
||||
zap.String("traceID", sp.SpanContext().TraceID().String()),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("dbName", request.DbName))
|
||||
zap.String("dbName", request.DbName),
|
||||
)
|
||||
|
||||
log.Info(rpcReceived(method))
|
||||
if err := node.sched.ddQueue.Enqueue(cct); err != nil {
|
||||
|
@ -174,8 +181,17 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
|
|||
}
|
||||
|
||||
log.Info(rpcDone(method))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
return cct.result, nil
|
||||
}
|
||||
|
||||
|
@ -189,7 +205,11 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
|
|||
|
||||
method := "DropDatabase"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
dct := &dropDatabaseTask{
|
||||
ctx: ctx,
|
||||
|
@ -198,9 +218,11 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
|
|||
rootCoord: node.rootCoord,
|
||||
}
|
||||
|
||||
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
|
||||
log := log.With(
|
||||
zap.String("traceID", sp.SpanContext().TraceID().String()),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("dbName", request.DbName))
|
||||
zap.String("dbName", request.DbName),
|
||||
)
|
||||
|
||||
log.Info(rpcReceived(method))
|
||||
if err := node.sched.ddQueue.Enqueue(dct); err != nil {
|
||||
|
@ -217,8 +239,17 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
|
|||
}
|
||||
|
||||
log.Info(rpcDone(method))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
return dct.result, nil
|
||||
}
|
||||
|
||||
|
@ -234,7 +265,11 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData
|
|||
|
||||
method := "ListDatabases"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
dct := &listDatabaseTask{
|
||||
ctx: ctx,
|
||||
|
@ -243,8 +278,10 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData
|
|||
rootCoord: node.rootCoord,
|
||||
}
|
||||
|
||||
log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
log := log.With(
|
||||
zap.String("traceID", sp.SpanContext().TraceID().String()),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
)
|
||||
|
||||
log.Info(rpcReceived(method))
|
||||
|
||||
|
@ -264,8 +301,17 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData
|
|||
}
|
||||
|
||||
log.Info(rpcDone(method), zap.Int("num of db", len(dct.result.DbNames)))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
return dct.result, nil
|
||||
}
|
||||
|
||||
|
@ -281,7 +327,11 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
|
|||
method := "CreateCollection"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
cct := &createCollectionTask{
|
||||
ctx: ctx,
|
||||
|
@ -299,7 +349,8 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.Int("len(schema)", lenOfSchema),
|
||||
zap.Int32("shards_num", request.ShardsNum),
|
||||
zap.String("consistency_level", request.ConsistencyLevel.String()))
|
||||
zap.String("consistency_level", request.ConsistencyLevel.String()),
|
||||
)
|
||||
|
||||
log.Debug(rpcReceived(method))
|
||||
|
||||
|
@ -316,7 +367,8 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
|
|||
rpcEnqueued(method),
|
||||
zap.Uint64("BeginTs", cct.BeginTs()),
|
||||
zap.Uint64("EndTs", cct.EndTs()),
|
||||
zap.Uint64("timestamp", request.Base.Timestamp))
|
||||
zap.Uint64("timestamp", request.Base.Timestamp),
|
||||
)
|
||||
|
||||
if err := cct.WaitToFinish(); err != nil {
|
||||
log.Warn(
|
||||
|
@ -332,10 +384,19 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
|
|||
log.Debug(
|
||||
rpcDone(method),
|
||||
zap.Uint64("BeginTs", cct.BeginTs()),
|
||||
zap.Uint64("EndTs", cct.EndTs()))
|
||||
zap.Uint64("EndTs", cct.EndTs()),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return cct.result, nil
|
||||
}
|
||||
|
||||
|
@ -349,7 +410,11 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
|
|||
defer sp.End()
|
||||
method := "DropCollection"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
dct := &dropCollectionTask{
|
||||
ctx: ctx,
|
||||
|
@ -363,7 +428,8 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
|
|||
log := log.Ctx(ctx).With(
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
zap.String("collection", request.CollectionName),
|
||||
)
|
||||
|
||||
log.Debug("DropCollection received")
|
||||
|
||||
|
@ -375,9 +441,11 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
|
|||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
log.Debug("DropCollection enqueued",
|
||||
log.Debug(
|
||||
"DropCollection enqueued",
|
||||
zap.Uint64("BeginTs", dct.BeginTs()),
|
||||
zap.Uint64("EndTs", dct.EndTs()))
|
||||
zap.Uint64("EndTs", dct.EndTs()),
|
||||
)
|
||||
|
||||
if err := dct.WaitToFinish(); err != nil {
|
||||
log.Warn("DropCollection failed to WaitToFinish",
|
||||
|
@ -389,12 +457,22 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
|
|||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
log.Debug("DropCollection done",
|
||||
log.Debug(
|
||||
"DropCollection done",
|
||||
zap.Uint64("BeginTs", dct.BeginTs()),
|
||||
zap.Uint64("EndTs", dct.EndTs()))
|
||||
zap.Uint64("EndTs", dct.EndTs()),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return dct.result, nil
|
||||
}
|
||||
|
||||
|
@ -410,13 +488,17 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
|
|||
defer sp.End()
|
||||
method := "HasCollection"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
zap.String("collection", request.CollectionName),
|
||||
)
|
||||
|
||||
log.Debug("HasCollection received")
|
||||
|
||||
|
@ -438,9 +520,11 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
|
|||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("HasCollection enqueued",
|
||||
log.Debug(
|
||||
"HasCollection enqueued",
|
||||
zap.Uint64("BeginTS", hct.BeginTs()),
|
||||
zap.Uint64("EndTS", hct.EndTs()))
|
||||
zap.Uint64("EndTS", hct.EndTs()),
|
||||
)
|
||||
|
||||
if err := hct.WaitToFinish(); err != nil {
|
||||
log.Warn("HasCollection failed to WaitToFinish",
|
||||
|
@ -455,13 +539,22 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
|
|||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("HasCollection done",
|
||||
log.Debug(
|
||||
"HasCollection done",
|
||||
zap.Uint64("BeginTS", hct.BeginTs()),
|
||||
zap.Uint64("EndTS", hct.EndTs()))
|
||||
zap.Uint64("EndTS", hct.EndTs()),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return hct.result, nil
|
||||
}
|
||||
|
||||
|
@ -475,8 +568,12 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
|
|||
defer sp.End()
|
||||
method := "LoadCollection"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
lct := &loadCollectionTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
|
@ -489,7 +586,8 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
|
|||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Bool("refreshMode", request.Refresh))
|
||||
zap.Bool("refreshMode", request.Refresh),
|
||||
)
|
||||
|
||||
log.Debug("LoadCollection received")
|
||||
|
||||
|
@ -502,9 +600,11 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
|
|||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
log.Debug("LoadCollection enqueued",
|
||||
log.Debug(
|
||||
"LoadCollection enqueued",
|
||||
zap.Uint64("BeginTS", lct.BeginTs()),
|
||||
zap.Uint64("EndTS", lct.EndTs()))
|
||||
zap.Uint64("EndTS", lct.EndTs()),
|
||||
)
|
||||
|
||||
if err := lct.WaitToFinish(); err != nil {
|
||||
log.Warn("LoadCollection failed to WaitToFinish",
|
||||
|
@ -516,13 +616,22 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
|
|||
return merr.Status(err), nil
|
||||
}
|
||||
|
||||
log.Debug("LoadCollection done",
|
||||
log.Debug(
|
||||
"LoadCollection done",
|
||||
zap.Uint64("BeginTS", lct.BeginTs()),
|
||||
zap.Uint64("EndTS", lct.EndTs()))
|
||||
zap.Uint64("EndTS", lct.EndTs()),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
return lct.result, nil
|
||||
}
|
||||
|
||||
|
@ -2379,11 +2488,15 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
receiveSize := proto.Size(request)
|
||||
metrics.ProxyReceiveBytes.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel, request.GetCollectionName()).Add(float64(receiveSize))
|
||||
metrics.SearchLabel,
|
||||
request.GetCollectionName(),
|
||||
).Add(float64(receiveSize))
|
||||
|
||||
metrics.ProxyReceivedNQ.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel, request.GetCollectionName()).Add(float64(request.GetNq()))
|
||||
metrics.SearchLabel,
|
||||
request.GetCollectionName(),
|
||||
).Add(float64(request.GetNq()))
|
||||
|
||||
rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq()))
|
||||
|
||||
|
@ -2392,14 +2505,29 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
method := "Search"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search")
|
||||
defer sp.End()
|
||||
|
||||
if request.SearchByPrimaryKeys {
|
||||
placeholderGroupBytes, err := node.getVectorPlaceholderGroupForSearchByPks(ctx, request)
|
||||
if err != nil {
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
request.PlaceholderGroup = placeholderGroupBytes
|
||||
}
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
|
@ -2428,7 +2556,8 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)),
|
||||
zap.Any("OutputFields", request.OutputFields),
|
||||
zap.Any("search_params", request.SearchParams),
|
||||
zap.Uint64("guarantee_timestamp", guaranteeTs))
|
||||
zap.Uint64("guarantee_timestamp", guaranteeTs),
|
||||
)
|
||||
|
||||
defer func() {
|
||||
span := tr.ElapseSpan()
|
||||
|
@ -2442,10 +2571,14 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
|
||||
log.Warn(
|
||||
rpcFailedToEnqueue(method),
|
||||
zap.Error(err))
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.AbandonLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.AbandonLabel,
|
||||
).Inc()
|
||||
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Status(err),
|
||||
|
@ -2455,15 +2588,20 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
|
||||
log.Debug(
|
||||
rpcEnqueued(method),
|
||||
zap.Uint64("timestamp", qt.Base.Timestamp))
|
||||
zap.Uint64("timestamp", qt.Base.Timestamp),
|
||||
)
|
||||
|
||||
if err := qt.WaitToFinish(); err != nil {
|
||||
log.Warn(
|
||||
rpcFailedToWaitToFinish(method),
|
||||
zap.Error(err))
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.FailLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.FailLabel,
|
||||
).Inc()
|
||||
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Status(err),
|
||||
|
@ -2471,19 +2609,34 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
}
|
||||
|
||||
span := tr.CtxRecord(ctx, "wait search result")
|
||||
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel).Observe(float64(span.Milliseconds()))
|
||||
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel,
|
||||
).Observe(float64(span.Milliseconds()))
|
||||
|
||||
tr.CtxRecord(ctx, "wait search result")
|
||||
log.Debug(rpcDone(method))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
|
||||
metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(qt.result.GetResults().GetNumQueries()))
|
||||
|
||||
searchDur := tr.ElapseSpan().Milliseconds()
|
||||
metrics.ProxySQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel).Observe(float64(searchDur))
|
||||
metrics.ProxyCollectionSQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel, request.CollectionName).Observe(float64(searchDur))
|
||||
metrics.ProxySQLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel,
|
||||
).Observe(float64(searchDur))
|
||||
|
||||
metrics.ProxyCollectionSQLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel,
|
||||
request.CollectionName,
|
||||
).Observe(float64(searchDur))
|
||||
|
||||
if qt.result != nil {
|
||||
sentSize := proto.Size(qt.result)
|
||||
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
|
||||
|
@ -2492,6 +2645,61 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
return qt.result, nil
|
||||
}
|
||||
|
||||
func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) {
|
||||
placeholderGroup := &commonpb.PlaceholderGroup{}
|
||||
err := proto.Unmarshal(request.PlaceholderGroup, placeholderGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(placeholderGroup.Placeholders) != 1 || len(placeholderGroup.Placeholders[0].Values) != 1 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("please provide primary key")
|
||||
}
|
||||
queryExpr := string(placeholderGroup.Placeholders[0].Values[0])
|
||||
|
||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, request.SearchParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryRequest := &milvuspb.QueryRequest{
|
||||
Base: request.Base,
|
||||
DbName: request.DbName,
|
||||
CollectionName: request.CollectionName,
|
||||
Expr: queryExpr,
|
||||
OutputFields: []string{annsField},
|
||||
PartitionNames: request.PartitionNames,
|
||||
TravelTimestamp: request.TravelTimestamp,
|
||||
GuaranteeTimestamp: request.GuaranteeTimestamp,
|
||||
QueryParams: nil,
|
||||
NotReturnAllMeta: request.NotReturnAllMeta,
|
||||
ConsistencyLevel: request.ConsistencyLevel,
|
||||
UseDefaultConsistency: request.UseDefaultConsistency,
|
||||
}
|
||||
|
||||
queryResults, _ := node.Query(ctx, queryRequest)
|
||||
|
||||
err = merr.Error(queryResults.GetStatus())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var vectorFieldsData *schemapb.FieldData
|
||||
for _, fieldsData := range queryResults.GetFieldsData() {
|
||||
if fieldsData.GetFieldName() == annsField {
|
||||
vectorFieldsData = fieldsData
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(vectorFieldsData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return placeholderGroupBytes, nil
|
||||
}
|
||||
|
||||
// Flush notify data nodes to persist the data of collection.
|
||||
func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) {
|
||||
resp := &milvuspb.FlushResponse{
|
||||
|
@ -2567,11 +2775,15 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
receiveSize := proto.Size(request)
|
||||
metrics.ProxyReceiveBytes.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel, request.GetCollectionName()).Add(float64(receiveSize))
|
||||
metrics.QueryLabel,
|
||||
request.GetCollectionName(),
|
||||
).Add(float64(receiveSize))
|
||||
|
||||
metrics.ProxyReceivedNQ.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.SearchLabel, request.GetCollectionName()).Add(float64(1))
|
||||
metrics.SearchLabel,
|
||||
request.GetCollectionName(),
|
||||
).Add(float64(1))
|
||||
|
||||
rateCol.Add(internalpb.RateType_DQLQuery.String(), 1)
|
||||
|
||||
|
@ -2602,14 +2814,18 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
|
||||
method := "Query"
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.TotalLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.TotalLabel,
|
||||
).Inc()
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Strings("partitions", request.PartitionNames))
|
||||
zap.Strings("partitions", request.PartitionNames),
|
||||
)
|
||||
|
||||
defer func() {
|
||||
span := tr.ElapseSpan()
|
||||
|
@ -2629,15 +2845,20 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
zap.String("expr", request.Expr),
|
||||
zap.Strings("OutputFields", request.OutputFields),
|
||||
zap.Uint64("travel_timestamp", request.TravelTimestamp),
|
||||
zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp))
|
||||
zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp),
|
||||
)
|
||||
|
||||
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
|
||||
log.Warn(
|
||||
rpcFailedToEnqueue(method),
|
||||
zap.Error(err))
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.AbandonLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.AbandonLabel,
|
||||
).Inc()
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: merr.Status(err),
|
||||
|
@ -2660,21 +2881,34 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
}, nil
|
||||
}
|
||||
span := tr.CtxRecord(ctx, "wait query result")
|
||||
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel).Observe(float64(span.Milliseconds()))
|
||||
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel,
|
||||
).Observe(float64(span.Milliseconds()))
|
||||
|
||||
log.Debug(rpcDone(method))
|
||||
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
method,
|
||||
metrics.SuccessLabel,
|
||||
).Inc()
|
||||
|
||||
metrics.ProxySQLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
metrics.ProxyCollectionSQLatency.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel,
|
||||
request.CollectionName,
|
||||
).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
metrics.ProxySQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
metrics.ProxyCollectionSQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),
|
||||
metrics.QueryLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
sentSize := proto.Size(qt.result)
|
||||
rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize))
|
||||
metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize))
|
||||
|
||||
return qt.result, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,10 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
@ -1005,6 +1008,7 @@ func TestProxy(t *testing.T) {
|
|||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
var insertedIds []int64
|
||||
wg.Add(1)
|
||||
t.Run("insert", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
|
@ -1016,6 +1020,13 @@ func TestProxy(t *testing.T) {
|
|||
assert.Equal(t, rowNum, len(resp.SuccIndex))
|
||||
assert.Equal(t, 0, len(resp.ErrIndex))
|
||||
assert.Equal(t, int64(rowNum), resp.InsertCnt)
|
||||
|
||||
switch field := resp.GetIDs().GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
insertedIds = field.IntId.GetData()
|
||||
default:
|
||||
t.Fatalf("Unexpected ID type")
|
||||
}
|
||||
})
|
||||
|
||||
// TODO(dragondriver): proxy.Delete()
|
||||
|
@ -1328,6 +1339,135 @@ func TestProxy(t *testing.T) {
|
|||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
})
|
||||
|
||||
nprobe := 10
|
||||
topk := 10
|
||||
roundDecimal := 6
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
constructVectorsPlaceholderGroup := func() *commonpb.PlaceholderGroup {
|
||||
values := make([][]byte, 0, nq)
|
||||
for i := 0; i < nq; i++ {
|
||||
bs := make([]byte, 0, dim*4)
|
||||
for j := 0; j < dim; j++ {
|
||||
var buffer bytes.Buffer
|
||||
f := rand.Float32()
|
||||
err := binary.Write(&buffer, common.Endian, f)
|
||||
assert.NoError(t, err)
|
||||
bs = append(bs, buffer.Bytes()...)
|
||||
}
|
||||
values = append(values, bs)
|
||||
}
|
||||
|
||||
return &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: values,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
constructSearchRequest := func() *milvuspb.SearchRequest {
|
||||
plg := constructVectorsPlaceholderGroup()
|
||||
plgBs, err := proto.Marshal(plg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
params := make(map[string]string)
|
||||
params["nprobe"] = strconv.Itoa(nprobe)
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
searchParams := []*commonpb.KeyValuePair{
|
||||
{Key: MetricTypeKey, Value: metric.L2},
|
||||
{Key: SearchParamsKey, Value: string(b)},
|
||||
{Key: AnnsFieldKey, Value: floatVecField},
|
||||
{Key: TopKKey, Value: strconv.Itoa(topk)},
|
||||
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
|
||||
}
|
||||
|
||||
return &milvuspb.SearchRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionNames: nil,
|
||||
Dsl: expr,
|
||||
PlaceholderGroup: plgBs,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: nil,
|
||||
SearchParams: searchParams,
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 0,
|
||||
SearchByPrimaryKeys: false,
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("search", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
req := constructSearchRequest()
|
||||
|
||||
resp, err := proxy.Search(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
})
|
||||
|
||||
constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup {
|
||||
expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0])
|
||||
exprBytes := []byte(expr)
|
||||
|
||||
return &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_None,
|
||||
Values: [][]byte{exprBytes},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
constructSearchByPksRequest := func() *milvuspb.SearchRequest {
|
||||
plg := constructPrimaryKeysPlaceholderGroup()
|
||||
plgBs, err := proto.Marshal(plg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
params := make(map[string]string)
|
||||
params["nprobe"] = strconv.Itoa(nprobe)
|
||||
b, err := json.Marshal(params)
|
||||
assert.NoError(t, err)
|
||||
searchParams := []*commonpb.KeyValuePair{
|
||||
{Key: MetricTypeKey, Value: metric.L2},
|
||||
{Key: SearchParamsKey, Value: string(b)},
|
||||
{Key: AnnsFieldKey, Value: floatVecField},
|
||||
{Key: TopKKey, Value: strconv.Itoa(topk)},
|
||||
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
|
||||
}
|
||||
|
||||
return &milvuspb.SearchRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionNames: nil,
|
||||
Dsl: "",
|
||||
PlaceholderGroup: plgBs,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: nil,
|
||||
SearchParams: searchParams,
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 0,
|
||||
SearchByPrimaryKeys: true,
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("search by primary keys", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
req := constructSearchByPksRequest()
|
||||
resp, err := proxy.Search(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
})
|
||||
|
||||
// nprobe := 10
|
||||
// topk := 10
|
||||
// roundDecimal := 6
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
package funcutil
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
|
||||
placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
placeholderGroup := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
|
||||
}
|
||||
|
||||
bytes, _ := proto.Marshal(placeholderGroup)
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.PlaceholderValue, error) {
|
||||
switch fieldData.Type {
|
||||
case schemapb.DataType_FloatVector:
|
||||
vectors := fieldData.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
|
||||
if !ok {
|
||||
return nil, errors.New("vector data is not schemapb.VectorField_FloatVector")
|
||||
}
|
||||
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: flattenedFloatVectorsToByteVectors(x.FloatVector.Data, int(vectors.Dim)),
|
||||
}
|
||||
return placeholderValue, nil
|
||||
case schemapb.DataType_BinaryVector:
|
||||
vectors := fieldData.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
|
||||
if !ok {
|
||||
return nil, errors.New("vector data is not schemapb.VectorField_BinaryVector")
|
||||
}
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_BinaryVector,
|
||||
Values: flattenedByteVectorsToByteVectors(x.BinaryVector, int(vectors.Dim)),
|
||||
}
|
||||
return placeholderValue, nil
|
||||
case schemapb.DataType_Float16Vector:
|
||||
vectors := fieldData.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
|
||||
if !ok {
|
||||
return nil, errors.New("vector data is not schemapb.VectorField_Float16Vector")
|
||||
}
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_Float16Vector,
|
||||
Values: flattenedFloat16VectorsToByteVectors(x.Float16Vector, int(vectors.Dim)),
|
||||
}
|
||||
return placeholderValue, nil
|
||||
default:
|
||||
return nil, errors.New("field is not a vector field")
|
||||
}
|
||||
}
|
||||
|
||||
func flattenedFloatVectorsToByteVectors(flattenedVectors []float32, dimension int) [][]byte {
|
||||
floatVectors := flattenedFloatVectorsToFloatVectors(flattenedVectors, dimension)
|
||||
result := make([][]byte, 0)
|
||||
for _, floatVector := range floatVectors {
|
||||
result = append(result, floatVectorToByteVector(floatVector))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func flattenedFloatVectorsToFloatVectors(flattenedVectors []float32, dimension int) [][]float32 {
|
||||
result := make([][]float32, 0)
|
||||
for i := 0; i < len(flattenedVectors); i += dimension {
|
||||
result = append(result, flattenedVectors[i:i+dimension])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func floatVectorToByteVector(vector []float32) []byte {
|
||||
data := make([]byte, 0, 4*len(vector)) // float32 occupies 4 bytes
|
||||
buf := make([]byte, 4)
|
||||
for _, f := range vector {
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
|
||||
data = append(data, buf...)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func flattenedByteVectorsToByteVectors(flattenedVectors []byte, dimension int) [][]byte {
|
||||
result := make([][]byte, 0)
|
||||
for i := 0; i < len(flattenedVectors); i += dimension {
|
||||
result = append(result, flattenedVectors[i:i+dimension])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func flattenedFloat16VectorsToByteVectors(flattenedVectors []byte, dimension int) [][]byte {
|
||||
result := make([][]byte, 0)
|
||||
|
||||
vectorBytes := 2 * dimension
|
||||
|
||||
for i := 0; i < len(flattenedVectors); i += vectorBytes {
|
||||
result = append(result, flattenedVectors[i:i+vectorBytes])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package funcutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_flattenedByteVectorsToByteVectors(t *testing.T) {
|
||||
flattenedVectors := []byte{0, 1, 2, 3, 4, 5}
|
||||
dimension := 3
|
||||
|
||||
actual := flattenedByteVectorsToByteVectors(flattenedVectors, dimension)
|
||||
expected := [][]byte{
|
||||
{0, 1, 2},
|
||||
{3, 4, 5},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func Test_flattenedFloat16VectorsToByteVectors(t *testing.T) {
|
||||
flattenedVectors := []byte{0, 1, 2, 3, 4, 5, 6, 7}
|
||||
dimension := 2
|
||||
|
||||
actual := flattenedFloat16VectorsToByteVectors(flattenedVectors, dimension)
|
||||
expected := [][]byte{
|
||||
{0, 1, 2, 3},
|
||||
{4, 5, 6, 7},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
Loading…
Reference in New Issue