milvus/internal/util/function/text_embedding_function.go

325 lines
11 KiB
Go

/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"context"
"fmt"
"reflect"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/credentials"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
const (
	// Provider is the name of the function parameter that selects the
	// embedding service. It is not read in this file; presumably it is
	// consumed when FunctionBase is built (confirm).
	Provider string = "provider"

	// Identifiers of the text embedding service providers that
	// NewTextEmbeddingFunction knows how to construct.
	openAIProvider       string = "openai"
	azureOpenAIProvider  string = "azure_openai"
	aliDashScopeProvider string = "dashscope"
	bedrockProvider      string = "bedrock"
	vertexAIProvider     string = "vertexai"
	voyageAIProvider     string = "voyageai"
	cohereProvider       string = "cohere"
	siliconflowProvider  string = "siliconflow"
	teiProvider          string = "tei"
)
// hasEmptyString reports whether at least one element of texts is the empty
// string. A nil or zero-length slice yields false.
func hasEmptyString(texts []string) bool {
	for i := range texts {
		if len(texts[i]) == 0 {
			return true
		}
	}
	return false
}
// TextEmbeddingOutputsCheck validates the output fields of a text embedding
// function. It accepts exactly one field whose data type is FloatVector or
// Int8Vector and returns an error for any other shape.
func TextEmbeddingOutputsCheck(fields []*schemapb.FieldSchema) error {
	if len(fields) == 1 {
		switch fields[0].DataType {
		case schemapb.DataType_FloatVector, schemapb.DataType_Int8Vector:
			return nil
		}
	}
	return errors.New("TextEmbedding function output field must be a FloatVector or Int8Vector field")
}
type (
	// textEmbeddingProvider is the backend contract a text embedding service
	// implements for retrieval tasks.
	textEmbeddingProvider interface {
		// MaxBatch is the largest number of texts the callers in this file
		// send in one CallEmbedding request.
		MaxBatch() int
		// CallEmbedding embeds texts in the given mode. Callers in this file
		// expect the result to be either [][]float32 or [][]int8.
		CallEmbedding(texts []string, mode TextEmbeddingMode) (any, error)
		// FieldDim is the vector dimension declared on the output field.
		FieldDim() int64
	}

	// TextEmbeddingFunction turns text input fields into vector output
	// fields by delegating to a configured textEmbeddingProvider.
	TextEmbeddingFunction struct {
		FunctionBase
		embProvider textEmbeddingProvider
	}
)
// isValidInputDataType reports whether dataType can be used as the text
// input of an embedding function: only VarChar and Text qualify.
func isValidInputDataType(dataType schemapb.DataType) bool {
	switch dataType {
	case schemapb.DataType_VarChar, schemapb.DataType_Text:
		return true
	default:
		return false
	}
}
// NewTextEmbeddingFunction builds a TextEmbeddingFunction from a collection
// schema and a function schema.
//
// The function schema must name exactly one output field, and that field must
// be a FloatVector or Int8Vector (see TextEmbeddingOutputsCheck). The provider
// recorded on the FunctionBase selects which backend is constructed. Every
// backend receives the output field, the function schema, the
// provider-specific configuration from paramtable, and the configured
// credentials.
//
// It returns an error when the schema is malformed, when the provider is not
// one of the supported identifiers, or when the provider constructor fails.
func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmbeddingFunction, error) {
	if len(functionSchema.GetOutputFieldNames()) != 1 {
		return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames()))
	}

	base, err := NewFunctionBase(coll, functionSchema)
	if err != nil {
		return nil, err
	}
	if err := TextEmbeddingOutputsCheck(base.outputFields); err != nil {
		return nil, err
	}

	// Safe: TextEmbeddingOutputsCheck guarantees exactly one output field.
	outputField := base.outputFields[0]
	conf := paramtable.Get().FunctionCfg.GetTextEmbeddingProviderConfig(base.provider)
	// Named creds (not credentials) so the imported credentials package is
	// not shadowed for the rest of the function.
	creds := credentials.NewCredentials(paramtable.Get().CredentialCfg.GetCredentials())

	var (
		embP           textEmbeddingProvider
		newProviderErr error
	)
	switch base.provider {
	case openAIProvider:
		embP, newProviderErr = NewOpenAIEmbeddingProvider(outputField, functionSchema, conf, creds)
	case azureOpenAIProvider:
		embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(outputField, functionSchema, conf, creds)
	case bedrockProvider:
		// NOTE(review): the nil argument's meaning is defined by the provider
		// constructor (presumably an optional client override) — confirm.
		embP, newProviderErr = NewBedrockEmbeddingProvider(outputField, functionSchema, nil, conf, creds)
	case aliDashScopeProvider:
		embP, newProviderErr = NewAliDashScopeEmbeddingProvider(outputField, functionSchema, conf, creds)
	case vertexAIProvider:
		// NOTE(review): see the bedrock case for the nil argument.
		embP, newProviderErr = NewVertexAIEmbeddingProvider(outputField, functionSchema, nil, conf, creds)
	case voyageAIProvider:
		embP, newProviderErr = NewVoyageAIEmbeddingProvider(outputField, functionSchema, conf, creds)
	case cohereProvider:
		embP, newProviderErr = NewCohereEmbeddingProvider(outputField, functionSchema, conf, creds)
	case siliconflowProvider:
		embP, newProviderErr = NewSiliconflowEmbeddingProvider(outputField, functionSchema, conf, creds)
	case teiProvider:
		embP, newProviderErr = NewTEIEmbeddingProvider(outputField, functionSchema, conf, creds)
	default:
		return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s, %s, %s]", base.provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider, siliconflowProvider, teiProvider)
	}
	if newProviderErr != nil {
		return nil, newProviderErr
	}

	return &TextEmbeddingFunction{
		FunctionBase: *base,
		embProvider:  embP,
	}, nil
}
// Check sends a single embedding request (the text "check", InsertMode) to the
// provider. It verifies that the returned vectors have the dimension declared
// on the output field.
//
// It returns the provider's error if the call fails. It returns an error if
// the provider returns no vectors or an unsupported result type, or if the
// model's dimension differs from the schema dimension.
func (runner *TextEmbeddingFunction) Check() error {
	embds, err := runner.embProvider.CallEmbedding([]string{"check"}, InsertMode)
	if err != nil {
		return err
	}

	var dim int
	switch vecs := embds.(type) {
	case [][]float32:
		// Guard against an empty response; indexing vecs[0] would panic.
		if len(vecs) == 0 {
			return errors.New("The embedding provider returned no vectors for the check request")
		}
		dim = len(vecs[0])
	case [][]int8:
		if len(vecs) == 0 {
			return errors.New("The embedding provider returned no vectors for the check request")
		}
		dim = len(vecs[0])
	default:
		// %v prints "<nil>" for a nil result; calling .String() on the nil
		// reflect.Type would panic instead.
		return fmt.Errorf("Unsupport embedding type: %v", reflect.TypeOf(embds))
	}

	if dim != int(runner.embProvider.FieldDim()) {
		return fmt.Errorf("The dim set in the schema is inconsistent with the dim of the model, dim in schema is %d, dim of model is %d", runner.embProvider.FieldDim(), dim)
	}
	return nil
}
// MaxBatch reports the provider's per-request text limit; ProcessInsert and
// ProcessSearch reject inputs with more rows than this.
func (runner *TextEmbeddingFunction) MaxBatch() int {
	provider := runner.embProvider
	return provider.MaxBatch()
}

// GetCollectionName returns the collection name held by the embedded FunctionBase.
func (runner *TextEmbeddingFunction) GetCollectionName() string {
	return runner.FunctionBase.collectionName
}

// GetFunctionProvider returns the provider identifier held by the embedded FunctionBase.
func (runner *TextEmbeddingFunction) GetFunctionProvider() string {
	return runner.FunctionBase.provider
}

// GetFunctionTypeName returns the function type name held by the embedded FunctionBase.
func (runner *TextEmbeddingFunction) GetFunctionTypeName() string {
	return runner.FunctionBase.functionTypeName
}

// GetFunctionName returns the function name held by the embedded FunctionBase.
func (runner *TextEmbeddingFunction) GetFunctionName() string {
	return runner.FunctionBase.functionName
}
// packToFieldData flattens provider output into a single vector FieldData for
// the function's output field.
//
// embds must be [][]float32 (packed as a FloatVector) or [][]int8 (packed as
// an Int8Vector byte buffer). Each inner vector must have exactly FieldDim()
// elements; otherwise the flat buffer would not match the declared Dim. An
// unsupported type or a dimension mismatch returns an error instead of a
// FieldData with a missing or malformed payload.
func (runner *TextEmbeddingFunction) packToFieldData(embds any) ([]*schemapb.FieldData, error) {
	field := runner.GetOutputFields()[0]
	dim := runner.embProvider.FieldDim()

	outputField := &schemapb.FieldData{
		FieldId:   field.FieldID,
		FieldName: field.Name,
		Type:      field.DataType,
		IsDynamic: field.IsDynamic,
	}

	switch vecs := embds.(type) {
	case [][]float32:
		data := make([]float32, 0, len(vecs)*int(dim))
		for i, emb := range vecs {
			if int64(len(emb)) != dim {
				return nil, fmt.Errorf("Embedding %d has dim %d, but the output field dim is %d", i, len(emb), dim)
			}
			data = append(data, emb...)
		}
		outputField.Field = &schemapb.FieldData_Vectors{
			Vectors: &schemapb.VectorField{
				Data: &schemapb.VectorField_FloatVector{
					FloatVector: &schemapb.FloatArray{
						Data: data,
					},
				},
				Dim: dim,
			},
		}
	case [][]int8:
		data := make([]byte, 0, len(vecs)*int(dim))
		for i, emb := range vecs {
			if int64(len(emb)) != dim {
				return nil, fmt.Errorf("Embedding %d has dim %d, but the output field dim is %d", i, len(emb), dim)
			}
			// Int8Vector is carried as raw bytes; byte(v) keeps the
			// two's-complement bit pattern of each int8.
			for _, v := range emb {
				data = append(data, byte(v))
			}
		}
		outputField.Field = &schemapb.FieldData_Vectors{
			Vectors: &schemapb.VectorField{
				Data: &schemapb.VectorField_Int8Vector{
					Int8Vector: data,
				},
				Dim: dim,
			},
		}
	default:
		return nil, fmt.Errorf("Unknown embedding type: %T", embds)
	}
	return []*schemapb.FieldData{outputField}, nil
}
// ProcessInsert embeds the texts of a single VarChar/Text input field and
// returns the resulting vector field.
//
// It returns an error if inputs does not hold exactly one field, the field has
// an unsupported type, the field holds no texts, any text is empty, the row
// count exceeds MaxBatch, or the provider call or packing fails.
func (runner *TextEmbeddingFunction) ProcessInsert(ctx context.Context, inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
	if len(inputs) != 1 {
		return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs))
	}
	input := inputs[0]
	// GetType is nil-safe, so a nil element is reported rather than panicking.
	if !isValidInputDataType(input.GetType()) {
		return nil, fmt.Errorf("Text embedding only supports varchar or text field as input field, but got %s", schemapb.DataType_name[int32(input.GetType())])
	}

	texts := input.GetScalars().GetStringData().GetData()
	// len == 0 also catches a non-nil empty slice, which previously reached
	// the provider with nothing to embed.
	if len(texts) == 0 {
		return nil, errors.New("Input texts is empty")
	}
	if hasEmptyString(texts) {
		return nil, errors.New("There is an empty string in the input data, TextEmbedding function does not support empty text")
	}
	if numRows := len(texts); numRows > runner.MaxBatch() {
		return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
	}

	embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
	if err != nil {
		return nil, err
	}
	return runner.packToFieldData(embds)
}
// ProcessSearch embeds the query texts in the first placeholder of
// placeholderGroup using SearchMode. It returns a placeholder group of float
// or int8 vectors, matching the output field's data type.
//
// It returns an error if there is no placeholder, the query count exceeds
// MaxBatch, any query is empty, the provider call fails, or the provider's
// result type does not match the output field type.
func (runner *TextEmbeddingFunction) ProcessSearch(ctx context.Context, placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
	// Callers are expected to validate the group. Guard anyway so a missing
	// placeholder returns an error instead of an index-out-of-range panic.
	placeholders := placeholderGroup.GetPlaceholders()
	if len(placeholders) == 0 {
		return nil, errors.New("Text embedding function requires a placeholder with query texts, but got none")
	}
	texts := funcutil.GetVarCharFromPlaceholder(placeholders[0])

	numRows := len(texts)
	if numRows > runner.MaxBatch() {
		return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
	}
	if hasEmptyString(texts) {
		return nil, errors.New("There is an empty string in the queries, TextEmbedding function does not support empty text")
	}

	embds, err := runner.embProvider.CallEmbedding(texts, SearchMode)
	if err != nil {
		return nil, err
	}

	outputType := runner.GetOutputFields()[0].DataType
	switch outputType {
	case schemapb.DataType_FloatVector:
		vecs, ok := embds.([][]float32)
		if !ok {
			return nil, fmt.Errorf("Text embedding function expects [][]float32 for a FloatVector output field, but the provider returned %T", embds)
		}
		return funcutil.Float32VectorsToPlaceholderGroup(vecs), nil
	case schemapb.DataType_Int8Vector:
		vecs, ok := embds.([][]int8)
		if !ok {
			return nil, fmt.Errorf("Text embedding function expects [][]int8 for an Int8Vector output field, but the provider returned %T", embds)
		}
		return funcutil.Int8VectorsToPlaceholderGroup(vecs), nil
	}
	return nil, fmt.Errorf("Text embedding function doesn't support %s vector", schemapb.DataType_name[int32(outputType)])
}
// ProcessBulkInsert embeds the rows of a single VarChar/Text storage field. It
// returns a map from the output field ID to the corresponding float or int8
// vector field data.
//
// It returns an error if inputs does not hold exactly one field, the field has
// an unsupported type, its rows are not a []string, any row is empty (nulls
// are stored as empty strings in storage.FieldData), the provider call fails,
// or the provider returns an unknown result type.
func (runner *TextEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) {
	if len(inputs) != 1 {
		return nil, fmt.Errorf("TextEmbedding function only receives one input, but got [%d]", len(inputs))
	}
	if !isValidInputDataType(inputs[0].GetDataType()) {
		return nil, fmt.Errorf("TextEmbedding function only supports varchar or text field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].GetDataType())])
	}

	texts, ok := inputs[0].GetDataRows().([]string)
	if !ok {
		return nil, errors.New("Input texts is empty")
	}
	// In storage.FieldData, null is also stored as an empty string, so this
	// rejects nulls as well as genuinely empty texts.
	if hasEmptyString(texts) {
		return nil, errors.New("There is an empty string in the input data, TextEmbedding function does not support empty text")
	}

	embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
	if err != nil {
		return nil, err
	}

	dim := int(runner.embProvider.FieldDim())
	outputID := runner.outputFields[0].FieldID
	switch vecs := embds.(type) {
	case [][]float32:
		data := make([]float32, 0, len(texts)*dim)
		for _, emb := range vecs {
			data = append(data, emb...)
		}
		return map[storage.FieldID]storage.FieldData{
			outputID: &storage.FloatVectorFieldData{Data: data, Dim: dim},
		}, nil
	case [][]int8:
		data := make([]int8, 0, len(texts)*dim)
		for _, emb := range vecs {
			data = append(data, emb...)
		}
		return map[storage.FieldID]storage.FieldData{
			outputID: &storage.Int8VectorFieldData{Data: data, Dim: dim},
		}, nil
	default:
		return nil, fmt.Errorf("Unknown embedding type: %T", embds)
	}
}