mirror of https://github.com/milvus-io/milvus.git
284 lines
9.1 KiB
Go
284 lines
9.1 KiB
Go
/*
 * Licensed to the LF AI & Data foundation under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package function
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/storage"
|
|
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
|
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
|
)
|
|
|
|
// Runner is the common abstraction for a single user-defined function
// attached to a collection schema. Implementations transform input fields
// into output fields on the insert, search and bulk-insert paths.
type Runner interface {
	// GetSchema returns the function schema this runner was built from.
	GetSchema() *schemapb.FunctionSchema
	// GetOutputFields returns the schemas of the fields this function produces.
	GetOutputFields() []*schemapb.FieldSchema
	// GetCollectionName returns the owning collection's name (used as a metrics label).
	GetCollectionName() string
	// GetFunctionTypeName returns a human-readable function type name (metrics label).
	GetFunctionTypeName() string
	// GetFunctionName returns the function's configured name (metrics label).
	GetFunctionName() string
	// GetFunctionProvider returns the provider identifier (metrics label).
	GetFunctionProvider() string
	// Check validates the runner's configuration; invoked by ValidateFunctions.
	Check() error

	// MaxBatch reports the maximum number of rows/queries a single call may carry.
	MaxBatch() int
	// ProcessInsert computes output field data from the given input field data.
	ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
	// ProcessSearch rewrites a search placeholder group (e.g. text placeholders
	// into vectors — exact semantics depend on the implementation).
	ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
	// ProcessBulkInsert computes output columns from input columns, keyed by field id.
	ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
}
|
|
|
|
// FunctionExecutor drives all non-BM25 functions of a collection.
type FunctionExecutor struct {
	// runners maps a function's first output field id to its runner
	// (populated by NewFunctionExecutor; BM25 functions are excluded).
	runners map[int64]Runner
}
|
|
|
|
func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (Runner, error) {
|
|
switch schema.GetType() {
|
|
case schemapb.FunctionType_BM25: // ignore bm25 function
|
|
return nil, nil
|
|
case schemapb.FunctionType_TextEmbedding:
|
|
f, err := NewTextEmbeddingFunction(coll, schema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return f, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
|
|
}
|
|
}
|
|
|
|
// Since bm25 and embedding are implemented in different ways, the bm25 function is not verified here.
|
|
func ValidateFunctions(schema *schemapb.CollectionSchema) error {
|
|
for _, fSchema := range schema.Functions {
|
|
f, err := createFunction(schema, fSchema)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// ignore bm25 function
|
|
if f == nil {
|
|
continue
|
|
}
|
|
if err := f.Check(); err != nil {
|
|
return fmt.Errorf("Check function [%s:%s] failed, the err is: %v", fSchema.Name, fSchema.GetType().String(), err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
|
|
executor := &FunctionExecutor{
|
|
runners: make(map[int64]Runner),
|
|
}
|
|
for _, fSchema := range schema.Functions {
|
|
runner, err := createFunction(schema, fSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if runner != nil {
|
|
executor.runners[fSchema.GetOutputFieldIds()[0]] = runner
|
|
}
|
|
}
|
|
return executor, nil
|
|
}
|
|
|
|
func (executor *FunctionExecutor) processSingleFunction(ctx context.Context, runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
|
|
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames()))
|
|
for _, name := range runner.GetSchema().GetInputFieldNames() {
|
|
for _, field := range msg.FieldsData {
|
|
if field.GetFieldName() == name {
|
|
inputs = append(inputs, field)
|
|
}
|
|
}
|
|
}
|
|
if len(inputs) != len(runner.GetSchema().InputFieldIds) {
|
|
return nil, errors.New("Input field not found")
|
|
}
|
|
|
|
tr := timerecord.NewTimeRecorder("function ProcessInsert")
|
|
outputs, err := runner.ProcessInsert(ctx, inputs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
|
|
tr.CtxElapse(ctx, "function ProcessInsert done")
|
|
return outputs, nil
|
|
}
|
|
|
|
func (executor *FunctionExecutor) ProcessInsert(ctx context.Context, msg *msgstream.InsertMsg) error {
|
|
numRows := msg.NumRows
|
|
for _, runner := range executor.runners {
|
|
if numRows > uint64(runner.MaxBatch()) {
|
|
return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch())
|
|
}
|
|
}
|
|
|
|
outputs := make(chan []*schemapb.FieldData, len(executor.runners))
|
|
errChan := make(chan error, len(executor.runners))
|
|
var wg sync.WaitGroup
|
|
for _, runner := range executor.runners {
|
|
wg.Add(1)
|
|
go func(runner Runner) {
|
|
defer wg.Done()
|
|
data, err := executor.processSingleFunction(ctx, runner, msg)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
outputs <- data
|
|
}(runner)
|
|
}
|
|
wg.Wait()
|
|
close(errChan)
|
|
close(outputs)
|
|
|
|
// Collect all errors
|
|
var errs []error
|
|
for err := range errChan {
|
|
errs = append(errs, err)
|
|
}
|
|
if len(errs) > 0 {
|
|
return fmt.Errorf("%v", errs)
|
|
}
|
|
|
|
for output := range outputs {
|
|
msg.FieldsData = append(msg.FieldsData, output...)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (executor *FunctionExecutor) processSingleSearch(ctx context.Context, runner Runner, placeholderGroup []byte) ([]byte, error) {
|
|
pb := &commonpb.PlaceholderGroup{}
|
|
proto.Unmarshal(placeholderGroup, pb)
|
|
if len(pb.Placeholders) != 1 {
|
|
return nil, merr.WrapErrParameterInvalidMsg("No placeholders founded")
|
|
}
|
|
if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar {
|
|
return placeholderGroup, nil
|
|
}
|
|
|
|
tr := timerecord.NewTimeRecorder("function ProcessSearch")
|
|
res, err := runner.ProcessSearch(ctx, pb)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
|
|
tr.CtxElapse(ctx, "function ProcessSearch done")
|
|
return proto.Marshal(res)
|
|
}
|
|
|
|
func (executor *FunctionExecutor) prcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
|
|
runner, exist := executor.runners[req.FieldId]
|
|
if !exist {
|
|
return fmt.Errorf("Can not found function in field %d", req.FieldId)
|
|
}
|
|
if req.Nq > int64(runner.MaxBatch()) {
|
|
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch())
|
|
}
|
|
if newHolder, err := executor.processSingleSearch(ctx, runner, req.GetPlaceholderGroup()); err != nil {
|
|
return err
|
|
} else {
|
|
req.PlaceholderGroup = newHolder
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (executor *FunctionExecutor) prcessAdvanceSearch(ctx context.Context, req *internalpb.SearchRequest) error {
|
|
outputs := make(chan map[int64][]byte, len(req.GetSubReqs()))
|
|
errChan := make(chan error, len(req.GetSubReqs()))
|
|
var wg sync.WaitGroup
|
|
for idx, sub := range req.GetSubReqs() {
|
|
if runner, exist := executor.runners[sub.FieldId]; exist {
|
|
if sub.Nq > int64(runner.MaxBatch()) {
|
|
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch())
|
|
}
|
|
wg.Add(1)
|
|
go func(runner Runner, idx int64, placeholderGroup []byte) {
|
|
defer wg.Done()
|
|
if newHolder, err := executor.processSingleSearch(ctx, runner, placeholderGroup); err != nil {
|
|
errChan <- err
|
|
} else {
|
|
outputs <- map[int64][]byte{idx: newHolder}
|
|
}
|
|
}(runner, int64(idx), sub.GetPlaceholderGroup())
|
|
}
|
|
}
|
|
wg.Wait()
|
|
close(errChan)
|
|
close(outputs)
|
|
for err := range errChan {
|
|
return err
|
|
}
|
|
|
|
for output := range outputs {
|
|
for idx, holder := range output {
|
|
req.SubReqs[idx].PlaceholderGroup = holder
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (executor *FunctionExecutor) ProcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
|
|
if !req.IsAdvanced {
|
|
return executor.prcessSearch(ctx, req)
|
|
}
|
|
return executor.prcessAdvanceSearch(ctx, req)
|
|
}
|
|
|
|
func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {
|
|
inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds))
|
|
for idx, id := range runner.GetSchema().InputFieldIds {
|
|
field, exist := data.Data[id]
|
|
if !exist {
|
|
return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx])
|
|
}
|
|
inputs = append(inputs, field)
|
|
}
|
|
|
|
outputs, err := runner.ProcessBulkInsert(inputs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return outputs, nil
|
|
}
|
|
|
|
func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error {
|
|
// Since concurrency has already been used in the outer layer, only a serial logic access model is used here.
|
|
for _, runner := range executor.runners {
|
|
output, err := executor.processSingleBulkInsert(runner, data)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
for k, v := range output {
|
|
data.Data[k] = v
|
|
}
|
|
}
|
|
return nil
|
|
}
|