milvus/internal/util/function/function_executor.go


/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/

package function

import (
	"context"
	"fmt"
	"strconv"
	"sync"

	"github.com/cockroachdb/errors"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/v2/metrics"
	"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
)
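
// Runner executes a single declared function (for example a text-embedding
// function): it produces output fields for inserts and bulk inserts, and
// rewrites placeholder groups for searches.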
type Runner interface {
GetSchema() *schemapb.FunctionSchema
GetOutputFields() []*schemapb.FieldSchema
GetCollectionName() string
GetFunctionTypeName() string
GetFunctionName() string
GetFunctionProvider() string
Check() error
MaxBatch() int
ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
}
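
// FunctionExecutor runs all non-BM25 functions of a collection. Runners are
// keyed by the function's first output field ID.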
type FunctionExecutor struct {
runners map[int64]Runner
}
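
// createFunction instantiates a Runner for the given function schema.
// BM25 functions are executed elsewhere, so (nil, nil) is returned for them.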
func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (Runner, error) {
switch schema.GetType() {
case schemapb.FunctionType_BM25: // ignore bm25 function
return nil, nil
case schemapb.FunctionType_TextEmbedding:
f, err := NewTextEmbeddingFunction(coll, schema)
if err != nil {
return nil, err
}
return f, nil
default:
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
}
}

// ValidateFunctions checks every function declared in the collection schema.
// Since BM25 and embedding functions are implemented differently, BM25 functions are not verified here.
func ValidateFunctions(schema *schemapb.CollectionSchema) error {
for _, fSchema := range schema.Functions {
f, err := createFunction(schema, fSchema)
if err != nil {
return err
}
// ignore bm25 function
if f == nil {
continue
}
if err := f.Check(); err != nil {
return fmt.Errorf("Check function [%s:%s] failed, the err is: %v", fSchema.Name, fSchema.GetType().String(), err)
}
}
return nil
}
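
// NewFunctionExecutor creates an executor holding one Runner per non-BM25
// function declared in the collection schema.
//
// Typical usage (sketch):
//
//	executor, err := NewFunctionExecutor(collSchema)
//	if err != nil {
//		return err
//	}
//	if err := executor.ProcessInsert(ctx, insertMsg); err != nil {
//		return err
//	}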
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
executor := &FunctionExecutor{
runners: make(map[int64]Runner),
}
for _, fSchema := range schema.Functions {
runner, err := createFunction(schema, fSchema)
if err != nil {
return nil, err
}
if runner != nil {
executor.runners[fSchema.GetOutputFieldIds()[0]] = runner
}
}
return executor, nil
}
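
// processSingleFunction collects the runner's input fields from the insert
// message, runs ProcessInsert on them and records the function latency metric.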
func (executor *FunctionExecutor) processSingleFunction(ctx context.Context, runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames()))
for _, name := range runner.GetSchema().GetInputFieldNames() {
for _, field := range msg.FieldsData {
if field.GetFieldName() == name {
inputs = append(inputs, field)
}
}
}
if len(inputs) != len(runner.GetSchema().InputFieldIds) {
return nil, errors.New("Input field not found")
}
tr := timerecord.NewTimeRecorder("function ProcessInsert")
outputs, err := runner.ProcessInsert(ctx, inputs)
if err != nil {
return nil, err
}
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
tr.CtxElapse(ctx, "function ProcessInsert done")
return outputs, nil
}
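
// ProcessInsert executes all registered functions concurrently against the
// insert message and appends their output fields to msg.FieldsData.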
func (executor *FunctionExecutor) ProcessInsert(ctx context.Context, msg *msgstream.InsertMsg) error {
numRows := msg.NumRows
for _, runner := range executor.runners {
if numRows > uint64(runner.MaxBatch()) {
return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch())
}
}
outputs := make(chan []*schemapb.FieldData, len(executor.runners))
errChan := make(chan error, len(executor.runners))
var wg sync.WaitGroup
for _, runner := range executor.runners {
wg.Add(1)
go func(runner Runner) {
defer wg.Done()
data, err := executor.processSingleFunction(ctx, runner, msg)
if err != nil {
errChan <- err
return
}
outputs <- data
}(runner)
}
wg.Wait()
close(errChan)
close(outputs)
// Collect all errors
var errs []error
for err := range errChan {
errs = append(errs, err)
}
if len(errs) > 0 {
return fmt.Errorf("%v", errs)
}
for output := range outputs {
msg.FieldsData = append(msg.FieldsData, output...)
}
return nil
}
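
// processSingleSearch rewrites a serialized placeholder group: a single
// VarChar placeholder is handed to the runner's ProcessSearch (e.g. to embed
// the query text), while other placeholder types are returned unchanged.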
func (executor *FunctionExecutor) processSingleSearch(ctx context.Context, runner Runner, placeholderGroup []byte) ([]byte, error) {
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroup, pb)
if len(pb.Placeholders) != 1 {
return nil, merr.WrapErrParameterInvalidMsg("No placeholders founded")
}
if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar {
return placeholderGroup, nil
}
tr := timerecord.NewTimeRecorder("function ProcessSearch")
res, err := runner.ProcessSearch(ctx, pb)
if err != nil {
return nil, err
}
metrics.ProxyFunctionlatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), runner.GetCollectionName(), runner.GetFunctionTypeName(), runner.GetFunctionProvider(), runner.GetFunctionName()).Observe(float64(tr.RecordSpan().Milliseconds()))
tr.CtxElapse(ctx, "function ProcessSearch done")
return proto.Marshal(res)
}
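
// processSearch rewrites the placeholder group of a plain search request
// using the runner registered for req.FieldId.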
func (executor *FunctionExecutor) processSearch(ctx context.Context, req *internalpb.SearchRequest) error {
	runner, exist := executor.runners[req.FieldId]
	if !exist {
		return fmt.Errorf("cannot find function for field %d", req.FieldId)
	}
	if req.Nq > int64(runner.MaxBatch()) {
		return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch())
	}
	newHolder, err := executor.processSingleSearch(ctx, runner, req.GetPlaceholderGroup())
	if err != nil {
		return err
	}
	req.PlaceholderGroup = newHolder
	return nil
}
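
// processAdvancedSearch concurrently rewrites the placeholder group of every
// sub-request whose field has a registered runner.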
func (executor *FunctionExecutor) processAdvancedSearch(ctx context.Context, req *internalpb.SearchRequest) error {
outputs := make(chan map[int64][]byte, len(req.GetSubReqs()))
errChan := make(chan error, len(req.GetSubReqs()))
var wg sync.WaitGroup
for idx, sub := range req.GetSubReqs() {
if runner, exist := executor.runners[sub.FieldId]; exist {
if sub.Nq > int64(runner.MaxBatch()) {
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch())
}
wg.Add(1)
go func(runner Runner, idx int64, placeholderGroup []byte) {
defer wg.Done()
if newHolder, err := executor.processSingleSearch(ctx, runner, placeholderGroup); err != nil {
errChan <- err
} else {
outputs <- map[int64][]byte{idx: newHolder}
}
}(runner, int64(idx), sub.GetPlaceholderGroup())
}
}
wg.Wait()
close(errChan)
close(outputs)
for err := range errChan {
return err
}
for output := range outputs {
for idx, holder := range output {
req.SubReqs[idx].PlaceholderGroup = holder
}
}
return nil
}
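
// ProcessSearch dispatches the search request to the plain or advanced
// (multi-sub-request) handling path.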
func (executor *FunctionExecutor) ProcessSearch(ctx context.Context, req *internalpb.SearchRequest) error {
	if !req.IsAdvanced {
		return executor.processSearch(ctx, req)
	}
	return executor.processAdvancedSearch(ctx, req)
}
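
// processSingleBulkInsert gathers the runner's input columns from the
// bulk-insert data and returns the generated output columns.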
func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {
inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds))
for idx, id := range runner.GetSchema().InputFieldIds {
field, exist := data.Data[id]
if !exist {
return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx])
}
inputs = append(inputs, field)
}
outputs, err := runner.ProcessBulkInsert(inputs)
if err != nil {
return nil, err
}
return outputs, nil
}
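
// ProcessBulkInsert runs all registered functions over the bulk-insert data
// and merges their output columns back into it.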
func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error {
	// The caller already parallelizes bulk-insert processing, so the runners are executed serially here.
	for _, runner := range executor.runners {
		output, err := executor.processSingleBulkInsert(runner, data)
		if err != nil {
			return err
		}
for k, v := range output {
data.Data[k] = v
}
}
return nil
}