feat: Add Text Embedding Function (#36366)

https://github.com/milvus-io/milvus/issues/35856

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
pull/39580/head
junjiejiangjjj 2025-01-24 14:23:06 +08:00 committed by GitHub
parent 47d280d974
commit 16cbdfb3b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
47 changed files with 6232 additions and 10 deletions

15
go.mod
View File

@ -59,6 +59,10 @@ require (
require (
cloud.google.com/go/storage v1.43.0
github.com/antlr4-go/antlr/v4 v4.13.1
github.com/aws/aws-sdk-go-v2 v1.32.6
github.com/aws/aws-sdk-go-v2/config v1.28.6
github.com/aws/aws-sdk-go-v2/credentials v1.17.47
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0
github.com/bits-and-blooms/bitset v1.10.0
github.com/bytedance/sonic v1.12.2
github.com/cenkalti/backoff/v4 v4.2.1
@ -101,6 +105,17 @@ require (
github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7 // indirect
github.com/apache/thrift v0.18.1 // indirect
github.com/ardielle/ardielle-go v1.5.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect
github.com/aws/smithy-go v1.22.1 // indirect
github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bytedance/sonic/loader v0.2.0 // indirect

30
go.sum
View File

@ -120,6 +120,36 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4=
github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc=
github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo=
github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko=
github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw=
github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 h1:mfV5tcLXeRLbiyI4EHoHWH1sIU7JvbfXVvymUCIgZEo=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0/go.mod h1:YSSgYnasDKm5OjU3bOPkaz+2PFO6WjEQGIA6KQNsR3Q=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug=
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw=
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8=
github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro=
github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b h1:5JgaFtHFRnOPReItxvhMDXbvuBkjSWE+9glJyF466yw=

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/importutilv2"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/common"
@ -435,6 +436,107 @@ func (s *SchedulerSuite) TestScheduler_ImportFile() {
s.NoError(err)
}
func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
future := conc.Go(func() (struct{}, error) {
return struct{}{}, nil
})
return future, nil
})
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "128"},
},
},
{
FieldID: 101,
Name: "vec",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
},
{
FieldID: 102,
Name: "int64",
DataType: schemapb.DataType_Int64,
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{100},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{101},
OutputFieldNames: []string{"vec"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"},
},
},
},
}
var once sync.Once
data, err := testutil.CreateInsertData(schema, s.numRows)
s.NoError(err)
s.reader = importutilv2.NewMockReader(s.T())
s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) {
var res *storage.InsertData
once.Do(func() {
res = data
})
if res != nil {
return res, nil
}
return nil, io.EOF
})
importReq := &datapb.ImportRequest{
JobID: 10,
TaskID: 11,
CollectionID: 12,
PartitionIDs: []int64{13},
Vchannels: []string{"v0"},
Schema: schema,
Files: []*internalpb.ImportFile{
{
Paths: []string{"dummy.json"},
},
},
Ts: 1000,
IDRange: &datapb.IDRange{
Begin: 0,
End: int64(s.numRows),
},
RequestSegments: []*datapb.ImportRequestSegment{
{
SegmentID: 14,
PartitionID: 13,
Vchannel: "v0",
},
},
}
importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
s.manager.Add(importTask)
err = importTask.(*ImportTask).importFile(s.reader)
s.NoError(err)
}
func TestScheduler(t *testing.T) {
suite.Run(t, new(SchedulerSuite))
}

View File

@ -199,12 +199,39 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData) error {
}
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
if err := RunBm25Function(task, data); err != nil {
return err
}
if err := RunDenseEmbedding(task, data); err != nil {
return err
}
return nil
}
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
schema := task.GetSchema()
if function.HasNonBM25Functions(schema.Functions, []int64{}) {
exec, err := function.NewFunctionExecutor(schema)
if err != nil {
return err
}
if err := exec.ProcessBulkInsert(data); err != nil {
return err
}
}
return nil
}
func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
fns := task.GetSchema().GetFunctions()
for _, fn := range fns {
runner, err := function.NewFunctionRunner(task.GetSchema(), fn)
if err != nil {
return err
}
if runner == nil {
continue
}
inputDatas := make([]any, 0, len(fn.InputFieldIds))
for _, inputFieldID := range fn.InputFieldIds {
inputDatas = append(inputDatas, data.Data[inputFieldID].GetDataRows())

View File

@ -1099,8 +1099,8 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
if vectorField.GetIsFunctionOutput() {
for _, function := range collSchema.Functions {
if function.Type == schemapb.FunctionType_BM25 {
// TODO: currently only BM25 function is supported, thus guarantees one input field to one output field
if function.Type == schemapb.FunctionType_BM25 || function.Type == schemapb.FunctionType_TextEmbedding {
// TODO: currently only BM25 & text embedding function is supported, thus guarantees one input field to one output field
if function.OutputFieldNames[0] == vectorField.Name {
dataType = schemapb.DataType_VarChar
}

View File

@ -67,6 +67,9 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e
if err != nil {
return nil, err
}
if functionRunner == nil {
continue
}
node.functionRunners[tf.GetId()] = functionRunner
}
return node, nil

View File

@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
@ -141,6 +142,16 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
it.schema = schema.CollectionSchema
// Calculate embedding fields
if function.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
exec, err := function.NewFunctionExecutor(schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessInsert(it.insertMsg); err != nil {
return err
}
}
rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs
var rowIDBegin UniqueID

View File

@ -10,7 +10,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/testutils"
@ -308,3 +312,128 @@ func TestMaxInsertSize(t *testing.T) {
assert.ErrorIs(t, err, merr.ErrParameterTooLarge)
})
}
func TestInsertTask_Function(t *testing.T) {
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
FieldName: "text",
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"sentence", "sentence"},
},
},
},
},
}
data = append(data, &f)
collectionName := "TestInsertTask_function"
schema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "TestInsertTask_function",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
{
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "200"},
},
},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test_function",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"},
},
},
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
rc := mocks.NewMockRootCoordClient(t)
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
Status: merr.Status(nil),
ID: 11198,
Count: 10,
}, nil)
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
idAllocator.Start()
defer idAllocator.Close()
assert.NoError(t, err)
task := insertTask{
ctx: context.Background(),
insertMsg: &BaseInsertTask{
InsertRequest: &msgpb.InsertRequest{
CollectionName: collectionName,
DbName: "hooooooo",
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
Version: msgpb.InsertDataVersion_ColumnBased,
FieldsData: data,
NumRows: 2,
},
},
schema: schema,
idAllocator: idAllocator,
}
info := newSchemaInfo(schema)
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(info, nil)
cache.On("GetPartitionInfo",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&partitionInfo{
name: "p1",
partitionID: 10,
createdTimestamp: 10001,
createdUtcTimestamp: 10002,
}, nil)
cache.On("GetCollectionInfo",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&collectionInfo{schema: info}, nil)
cache.On("GetDatabaseInfo",
mock.Anything,
mock.Anything,
).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil)
globalMetaCache = cache
err = task.PreExecute(ctx)
assert.NoError(t, err)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
@ -362,6 +363,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
// fetch search_growing from search param
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
queryFieldIds := []int64{}
for index, subReq := range t.request.GetSubReqs() {
plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), subReq.GetExprTemplateValues())
if err != nil {
@ -383,6 +385,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
}
internalSubReq.FieldId = queryInfo.GetQueryFieldId()
queryFieldIds = append(queryFieldIds, internalSubReq.FieldId)
// set PartitionIDs for sub search
if t.partitionKeyMode {
// isolation has tighter constraint, check first
@ -421,6 +424,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
zap.Stringer("plan", plan)) // may be very large if large term passed.
}
var err error
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIds) {
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
return err
}
}
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
@ -428,7 +442,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
if t.partitionKeyMode {
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
}
var err error
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
if err != nil {
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
@ -499,6 +513,16 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId
t.SearchRequest.GroupSize = queryInfo.GroupSize
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
return err
}
}
log.Debug("proxy init search request",
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.Stringer("plan", plan)) // may be very large if large term passed.

View File

@ -26,6 +26,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@ -39,6 +40,7 @@ import (
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/proto/internalpb"
@ -499,6 +501,251 @@ func TestSearchTask_PreExecute(t *testing.T) {
})
}
func TestSearchTask_WithFunctions(t *testing.T) {
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
collectionName := "TestSearchTask_function"
schema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "TestSearchTask_function",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
{
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "200"},
},
},
{
FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "func1",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector1"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"},
},
},
{
Name: "func2",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{103},
OutputFieldNames: []string{"vector2"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"},
},
},
},
}
var err error
var (
rc = NewRootCoordMock()
qc = mocks.NewMockQueryCoordClient(t)
ctx = context.TODO()
)
defer rc.Close()
require.NoError(t, err)
mgr := newShardClientMgr()
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
err = InitMetaCache(ctx, rc, qc, mgr)
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string, data []string) *searchTask {
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_VarChar,
Values: lo.Map(data, func(str string, _ int) []byte { return []byte(str) }),
}
holder := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
}
holderByte, _ := proto.Marshal(holder)
task := &searchTask{
ctx: ctx,
collectionName: collectionName,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
Timestamp: uint64(time.Now().UnixNano()),
},
},
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
Nq: int64(len(data)),
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: "vector1"},
{Key: TopKKey, Value: "10"},
},
PlaceholderGroup: holderByte,
},
qc: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
collectionID := UniqueID(1000)
cache := NewMockCache(t)
info := newSchemaInfo(schema)
cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe()
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
globalMetaCache = cache
{
task := getSearchTask(t, collectionName, []string{"sentence"})
err = task.PreExecute(ctx)
assert.NoError(t, err)
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb)
assert.Equal(t, len(pb.Placeholders), 1)
assert.Equal(t, len(pb.Placeholders[0].Values), 1)
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
}
{
task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"})
err = task.PreExecute(ctx)
assert.NoError(t, err)
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb)
assert.Equal(t, len(pb.Placeholders), 1)
assert.Equal(t, len(pb.Placeholders[0].Values), 2)
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
}
// process failed
{
task := getSearchTask(t, collectionName, []string{"sentence"})
task.request.Nq = 10000
err = task.PreExecute(ctx)
assert.Error(t, err)
}
getHybridSearchTask := func(t *testing.T, collName string, data [][]string) *searchTask {
subReqs := []*milvuspb.SubSearchRequest{}
for _, item := range data {
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_VarChar,
Values: lo.Map(item, func(str string, _ int) []byte { return []byte(str) }),
}
holder := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
}
holderByte, _ := proto.Marshal(holder)
subReq := &milvuspb.SubSearchRequest{
PlaceholderGroup: holderByte,
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: "vector1"},
{Key: TopKKey, Value: "10"},
},
Nq: int64(len(item)),
}
subReqs = append(subReqs, subReq)
}
task := &searchTask{
ctx: ctx,
collectionName: collectionName,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
Timestamp: uint64(time.Now().UnixNano()),
},
},
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
SubReqs: subReqs,
SearchParams: []*commonpb.KeyValuePair{
{Key: LimitKey, Value: "10"},
},
},
qc: qc,
tr: timerecord.NewTimeRecorder("test-search"),
}
require.NoError(t, task.OnEnqueue())
return task
}
{
task := getHybridSearchTask(t, collectionName, [][]string{
{"sentence1"},
{"sentence2"},
})
err = task.PreExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, len(task.SearchRequest.SubReqs), 2)
for _, sub := range task.SearchRequest.SubReqs {
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(sub.PlaceholderGroup, pb)
assert.Equal(t, len(pb.Placeholders), 1)
assert.Equal(t, len(pb.Placeholders[0].Values), 1)
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
}
}
{
task := getHybridSearchTask(t, collectionName, [][]string{
{"sentence1", "sentence1"},
{"sentence2", "sentence2"},
{"sentence3", "sentence3"},
})
err = task.PreExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, len(task.SearchRequest.SubReqs), 3)
for _, sub := range task.SearchRequest.SubReqs {
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(sub.PlaceholderGroup, pb)
assert.Equal(t, len(pb.Placeholders), 1)
assert.Equal(t, len(pb.Placeholders[0].Values), 2)
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
}
}
// process failed
{
task := getHybridSearchTask(t, collectionName, [][]string{
{"sentence1", "sentence1"},
{"sentence2", "sentence2"},
{"sentence3", "sentence3"},
})
task.request.SubReqs[0].Nq = 10000
err = task.PreExecute(ctx)
assert.Error(t, err)
}
}
func getQueryCoord() *mocks.MockQueryCoord {
qc := &mocks.MockQueryCoord{}
qc.EXPECT().Start().Return(nil)

View File

@ -20,6 +20,7 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"math/rand"
"strconv"
"testing"
@ -1022,6 +1023,47 @@ func TestCreateCollectionTask(t *testing.T) {
err = task2.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("collection with embedding function ", func(t *testing.T) {
fmt.Println(schema)
schema.Functions = []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{varCharField},
OutputFieldNames: []string{floatVecField},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
},
},
}
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
task2 := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
err = task2.OnEnqueue()
assert.NoError(t, err)
err = task2.PreExecute(ctx)
assert.NoError(t, err)
})
}
func TestHasCollectionTask(t *testing.T) {

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
@ -152,6 +153,16 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err
}
// Calculate embedding fields
if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
if err != nil {
return err
}
if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil {
return err
}
}
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
// set upsertTask.insertRequest.rowIDs
tr := timerecord.NewTimeRecorder("applyPK")

View File

@ -27,8 +27,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/testutils"
)
@ -360,3 +365,133 @@ func TestUpsertTaskForReplicate(t *testing.T) {
assert.Error(t, err)
})
}
func TestUpsertTask_Function(t *testing.T) {
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
data := []*schemapb.FieldData{}
f1 := schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldId: 100,
FieldName: "id",
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{0, 1},
},
},
},
},
}
data = append(data, &f1)
f2 := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
FieldName: "text",
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"sentence", "sentence"},
},
},
},
},
}
data = append(data, &f2)
collectionName := "TestUpsertTask_function"
schema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "TestUpsertTask_function",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
{
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "200"},
},
},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test_function",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: ts.URL},
{Key: "dim", Value: "4"},
},
},
},
}
info := newSchemaInfo(schema)
collectionID := UniqueID(0)
cache := NewMockCache(t)
globalMetaCache = cache
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
rc := mocks.NewMockRootCoordClient(t)
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
Status: merr.Status(nil),
ID: collectionID,
Count: 10,
}, nil)
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
idAllocator.Start()
defer idAllocator.Close()
assert.NoError(t, err)
task := upsertTask{
ctx: context.Background(),
req: &milvuspb.UpsertRequest{
CollectionName: collectionName,
},
upsertMsg: &msgstream.UpsertMsg{
InsertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
),
CollectionName: collectionName,
DbName: "hooooooo",
Version: msgpb.InsertDataVersion_ColumnBased,
FieldsData: data,
NumRows: 2,
PartitionName: Params.CommonCfg.DefaultPartitionName.GetValue(),
},
},
},
idAllocator: idAllocator,
schema: info,
result: &milvuspb.MutationResult{},
}
err = task.insertPreExecute(ctx)
assert.NoError(t, err)
// process failed
{
oldRows := task.upsertMsg.InsertMsg.InsertRequest.NumRows
task.upsertMsg.InsertMsg.InsertRequest.NumRows = 10000
err = task.insertPreExecute(ctx)
assert.Error(t, err)
task.upsertMsg.InsertMsg.InsertRequest.NumRows = oldRows
}
}

View File

@ -37,6 +37,7 @@ import (
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
@ -705,6 +706,10 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
return err
}
}
if err := function.ValidateFunctions(coll); err != nil {
return err
}
return nil
}
@ -718,6 +723,10 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem
if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) {
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
}
case schemapb.FunctionType_TextEmbedding:
if len(fields) != 1 || fields[0].DataType != schemapb.DataType_FloatVector {
return fmt.Errorf("TextEmbedding function output field must be a FloatVector field")
}
default:
return fmt.Errorf("check output field for unknown function type")
}
@ -744,7 +753,10 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
if !h.EnableAnalyzer() {
return fmt.Errorf("BM25 function input field must set enable_analyzer to true")
}
case schemapb.FunctionType_TextEmbedding:
if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar {
return fmt.Errorf("TextEmbedding function input field must be a VARCHAR field")
}
default:
return fmt.Errorf("check input field with unknown function type")
}
@ -786,6 +798,10 @@ func checkFunctionBasicParams(function *schemapb.FunctionSchema) error {
if len(function.GetParams()) != 0 {
return fmt.Errorf("BM25 function accepts no params")
}
case schemapb.FunctionType_TextEmbedding:
if len(function.GetParams()) == 0 {
return fmt.Errorf("TextEmbedding function need provider and model_name params")
}
default:
return fmt.Errorf("check function params with unknown function type")
}
@ -942,7 +958,7 @@ func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb
expectColumnNum := 0
for _, field := range schema.GetFields() {
fieldName2Schema[field.Name] = field
if !field.GetIsFunctionOutput() {
if !IsBM25FunctionOutputField(field, schema) {
expectColumnNum++
}
}
@ -1494,12 +1510,12 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst
if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey {
return merr.WrapErrParameterInvalidMsg("primary key can't be with default value")
}
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema, schema) {
// when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false
autoGenFieldNum++
}
if _, ok := dataNameSet[fieldSchema.GetName()]; !ok {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema, schema) {
// autoGenField
continue
}
@ -1523,7 +1539,6 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst
zap.Int64("primaryKeyNum", int64(primaryKeyNum)))
return merr.WrapErrParameterInvalidMsg("more than 1 primary keys not supported, got %d", primaryKeyNum)
}
expectedNum := len(schema.Fields)
actualNum := len(insertMsg.FieldsData) + autoGenFieldNum
@ -2207,3 +2222,21 @@ func GetReplicateID(ctx context.Context, database, collectionName string) (strin
replicateID, _ := common.GetReplicateID(dbInfo.properties)
return replicateID, nil
}
func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb.CollectionSchema) bool {
if !(field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector) {
return false
}
for _, fSchema := range collSchema.Functions {
if fSchema.Type == schemapb.FunctionType_BM25 {
if len(fSchema.OutputFieldNames) != 0 && field.Name == fSchema.OutputFieldNames[0] {
return true
}
if len(fSchema.OutputFieldIds) != 0 && field.FieldID == fSchema.OutputFieldIds[0] {
return true
}
}
}
return false
}

View File

@ -2848,6 +2848,79 @@ func TestValidateFunction(t *testing.T) {
})
}
func TestValidateModelFunction(t *testing.T) {
t.Run("Valid model function schema", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
{
Name: "output_dense_field", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "bm25_func",
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{"input_field"},
OutputFieldNames: []string{"output_field"},
},
{
Name: "text_embedding_func",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"input_field"},
OutputFieldNames: []string{"output_dense_field"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: "mock_url"},
{Key: "dim", Value: "4"},
},
},
},
}
err := validateFunction(schema)
assert.NoError(t, err)
})
t.Run("Invalid function schema - Invalid function info ", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
{Name: "output_dense_field", DataType: schemapb.DataType_FloatVector},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "bm25_func",
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{"input_field"},
OutputFieldNames: []string{"output_field"},
},
{
Name: "text_embedding_func",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"input_field"},
OutputFieldNames: []string{"output_dense_field"},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "UnkownProvider"},
{Key: "model_name", Value: "text-embedding-ada-002"},
{Key: "api_key", Value: "mock"},
{Key: "url", Value: "mock_url"},
{Key: "dim", Value: "4"},
},
},
},
}
err := validateFunction(schema)
assert.Error(t, err)
})
}
func TestValidateFunctionInputField(t *testing.T) {
t.Run("Valid BM25 function input", func(t *testing.T) {
function := &schemapb.FunctionSchema{
@ -2920,6 +2993,28 @@ func TestValidateFunctionInputField(t *testing.T) {
err := checkFunctionInputField(function, fields)
assert.Error(t, err)
})
t.Run("Invalid TextEmbedding function input - multiple fields", func(t *testing.T) {
function := &schemapb.FunctionSchema{
Type: schemapb.FunctionType_TextEmbedding,
}
fields := []*schemapb.FieldSchema{}
err := checkFunctionInputField(function, fields)
assert.Error(t, err)
})
t.Run("Invalid TextEmbedding function input - wrong type", func(t *testing.T) {
function := &schemapb.FunctionSchema{
Type: schemapb.FunctionType_TextEmbedding,
}
fields := []*schemapb.FieldSchema{
{
DataType: schemapb.DataType_Int64,
},
}
err := checkFunctionInputField(function, fields)
assert.Error(t, err)
})
}
func TestValidateFunctionOutputField(t *testing.T) {
@ -2977,6 +3072,28 @@ func TestValidateFunctionOutputField(t *testing.T) {
err := checkFunctionOutputField(function, fields)
assert.Error(t, err)
})
t.Run("Invalid TextEmbedding function input - multiple fields", func(t *testing.T) {
function := &schemapb.FunctionSchema{
Type: schemapb.FunctionType_TextEmbedding,
}
fields := []*schemapb.FieldSchema{}
err := checkFunctionOutputField(function, fields)
assert.Error(t, err)
})
t.Run("Invalid TextEmbedding function input - wrong type", func(t *testing.T) {
function := &schemapb.FunctionSchema{
Type: schemapb.FunctionType_TextEmbedding,
}
fields := []*schemapb.FieldSchema{
{
DataType: schemapb.DataType_Int64,
},
}
err := checkFunctionOutputField(function, fields)
assert.Error(t, err)
})
}
func TestValidateFunctionBasicParams(t *testing.T) {
@ -3078,6 +3195,36 @@ func TestValidateFunctionBasicParams(t *testing.T) {
err := checkFunctionBasicParams(function)
assert.Error(t, err)
})
t.Run("Empty text embedding params", func(t *testing.T) {
function := &schemapb.FunctionSchema{
Name: "textEmbeddingParam",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"input1"},
OutputFieldNames: []string{"output1"},
}
err := checkFunctionBasicParams(function)
assert.Error(t, err)
})
}
func TestIsBM25FunctionOutputField(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsFunctionOutput: true},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "bm25_func",
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{"input_field"},
OutputFieldNames: []string{"output_field"},
},
},
}
assert.False(t, IsBM25FunctionOutputField(schema.Fields[0], schema))
assert.True(t, IsBM25FunctionOutputField(schema.Fields[1], schema))
}
func TestComputeRecall(t *testing.T) {

View File

@ -70,6 +70,9 @@ func newEmbeddingNode(collectionID int64, channelName string, manager *DataManag
if err != nil {
return nil, err
}
if functionRunner == nil {
continue
}
node.functionRunners = append(node.functionRunners, functionRunner)
}
return node, nil

View File

@ -382,7 +382,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap
}
for _, field := range collSchema.Fields {
if skipFunction && field.GetIsFunctionOutput() {
if skipFunction && IsBM25FunctionOutputField(field, collSchema) {
continue
}
@ -527,7 +527,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
}
length := 0
for _, field := range collSchema.Fields {
if field.GetIsFunctionOutput() {
if IsBM25FunctionOutputField(field, collSchema) {
continue
}
@ -1405,3 +1405,22 @@ func (ni NullableInt) GetValue() int {
func (ni NullableInt) IsNull() bool {
return ni.Value == nil
}
// TODO: unify the function implementation, storage/utils.go & proxy/util.go
func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb.CollectionSchema) bool {
if !(field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector) {
return false
}
for _, fSchema := range collSchema.Functions {
if fSchema.Type == schemapb.FunctionType_BM25 {
if len(fSchema.OutputFieldNames) != 0 && field.Name == fSchema.OutputFieldNames[0] {
return true
}
if len(fSchema.OutputFieldIds) != 0 && field.FieldID == fSchema.OutputFieldIds[0] {
return true
}
}
}
return false
}

View File

@ -1949,3 +1949,53 @@ func TestJson(t *testing.T) {
t.Log(string(ExtraBytes))
t.Log(ExtraLength)
}
func TestBM25Checker(t *testing.T) {
f1 := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
}
f2 := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"sparse"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{103},
}
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector",
DataType: schemapb.DataType_FloatVector,
IsFunctionOutput: true,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{
FieldID: 103, Name: "sparse",
DataType: schemapb.DataType_SparseFloatVector,
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{f1, f2},
}
for _, field := range schema.Fields {
isBm25 := IsBM25FunctionOutputField(field, schema)
if field.FieldID == 103 {
assert.True(t, isBm25)
} else {
assert.False(t, isBm25)
}
}
}

View File

@ -0,0 +1,149 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"os"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/ali"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type AliEmbeddingProvider struct {
fieldDim int64
client *ali.AliDashScopeEmbedding
modelName string
embedDimParam int64
outputType string
maxBatch int
timeoutSec int64
}
func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
if apiKey == "" {
apiKey = os.Getenv(dashscopeAKEnvStr)
}
if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
}
if url == "" {
url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
c := ali.NewAliDashScopeEmbeddingClient(apiKey, url)
return c, nil
}
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil {
return nil, err
}
var apiKey, url, modelName string
var dim int64
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case dimParamKey:
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
if err != nil {
return nil, err
}
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
default:
}
}
if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3)
}
c, err := createAliEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
}
maxBatch := 25
if modelName == TextEmbeddingV3 {
maxBatch = 6
}
provider := AliEmbeddingProvider{
client: c,
fieldDim: fieldDim,
modelName: modelName,
embedDimParam: dim,
// TextEmbedding only supports dense embedding
outputType: "dense",
maxBatch: maxBatch,
timeoutSec: 30,
}
return &provider, nil
}
func (provider *AliEmbeddingProvider) MaxBatch() int {
return 5 * provider.maxBatch
}
func (provider *AliEmbeddingProvider) FieldDim() int64 {
return provider.fieldDim
}
func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
numRows := len(texts)
var textType string
if mode == SearchMode {
textType = "query"
} else {
textType = "document"
}
data := make([][]float32, 0, numRows)
for i := 0; i < numRows; i += provider.maxBatch {
end := i + provider.maxBatch
if end > numRows {
end = numRows
}
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec)
if err != nil {
return nil, err
}
if end-i != len(resp.Output.Embeddings) {
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Output.Embeddings))
}
for _, item := range resp.Output.Embeddings {
if len(item.Embedding) != int(provider.fieldDim) {
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
provider.fieldDim, len(item.Embedding))
}
data = append(data, item.Embedding)
}
}
return data, nil
}

View File

@ -0,0 +1,202 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/ali"
)
func TestAliTextEmbeddingProvider(t *testing.T) {
suite.Run(t, new(AliTextEmbeddingProviderSuite))
}
type AliTextEmbeddingProviderSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *AliTextEmbeddingProviderSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
s.providers = []string{aliDashScopeProvider}
}
func createAliProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TextEmbeddingV3},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "4"},
},
}
switch providerName {
case aliDashScopeProvider:
return NewAliDashScopeEmbeddingProvider(schema, functionSchema)
default:
return nil, fmt.Errorf("Unknow provider")
}
}
func (s *AliTextEmbeddingProviderSuite) TestEmbedding() {
ts := CreateAliEmbeddingServer()
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
{
data := []string{"sentence"}
ret, err2 := provder.CallEmbedding(data, InsertMode)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(4, len(ret[0]))
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
}
{
data := []string{"sentence 1", "sentence 2", "sentence 3"}
ret, _ := provder.CallEmbedding(data, SearchMode)
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
}
}
}
func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res ali.EmbeddingResponse
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
TextIndex: 0,
})
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: []float32{1.0, 1.0},
TextIndex: 1,
})
res.Usage = ali.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
for _, providerName := range s.providers {
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], providerName)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res ali.EmbeddingResponse
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
TextIndex: 0,
})
res.Usage = ali.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence2"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *AliTextEmbeddingProviderSuite) TestCreateAliEmbeddingClient() {
_, err := createAliEmbeddingClient("", "")
s.Error(err)
os.Setenv(dashscopeAKEnvStr, "mock_key")
defer os.Unsetenv(dashscopeAKEnvStr)
_, err = createAliEmbeddingClient("", "")
s.NoError(err)
}
func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: "UnkownModels"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
},
}
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
// invalid dim
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TextEmbeddingV3}
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err = NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}

View File

@ -0,0 +1,202 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type BedrockClient interface {
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
}
type BedrockEmbeddingProvider struct {
fieldDim int64
client BedrockClient
modelName string
embedDimParam int64
normalize bool
maxBatch int
timeoutSec int
}
func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) {
if awsAccessKeyId == "" {
awsAccessKeyId = os.Getenv(bedrockAccessKeyId)
}
if awsAccessKeyId == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId)
}
if awsSecretAccessKey == "" {
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
}
if awsSecretAccessKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr)
}
if region == "" {
return nil, fmt.Errorf("Missing AWS Service region. Please pass `region` param.")
}
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
awsAccessKeyId, awsSecretAccessKey, "")),
)
if err != nil {
return nil, err
}
return bedrockruntime.NewFromConfig(cfg), nil
}
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil {
return nil, err
}
var awsAccessKeyId, awsSecretAccessKey, region, modelName string
var dim int64
normalize := true
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case dimParamKey:
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
if err != nil {
return nil, err
}
case awsAKIdParamKey:
awsAccessKeyId = param.Value
case awsSAKParamKey:
awsSecretAccessKey = param.Value
case regionParamKey:
region = param.Value
case normalizeParamKey:
switch strings.ToLower(param.Value) {
case "false":
normalize = false
case "true":
normalize = true
default:
return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value)
}
default:
}
}
if modelName != BedRockTitanTextEmbeddingsV2 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s]",
modelName, BedRockTitanTextEmbeddingsV2)
}
var client BedrockClient
if c == nil {
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)
if err != nil {
return nil, err
}
} else {
client = c
}
return &BedrockEmbeddingProvider{
client: client,
fieldDim: fieldDim,
modelName: modelName,
embedDimParam: dim,
normalize: normalize,
maxBatch: 1,
timeoutSec: 30,
}, nil
}
func (provider *BedrockEmbeddingProvider) MaxBatch() int {
// The bedrock model does not support batches, we support a small batch on the milvus side.
return 12 * provider.maxBatch
}
func (provider *BedrockEmbeddingProvider) FieldDim() int64 {
return provider.fieldDim
}
func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, _ TextEmbeddingMode) ([][]float32, error) {
numRows := len(texts)
data := make([][]float32, 0, numRows)
for i := 0; i < numRows; i += 1 {
payload := BedRockRequest{
InputText: texts[i],
Normalize: provider.normalize,
}
if provider.embedDimParam != 0 {
payload.Dimensions = provider.embedDimParam
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
output, err := provider.client.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{
Body: payloadBytes,
ModelId: aws.String(provider.modelName),
ContentType: aws.String("application/json"),
})
if err != nil {
return nil, err
}
var resp BedRockResponse
err = json.Unmarshal(output.Body, &resp)
if err != nil {
return nil, err
}
if len(resp.Embedding) != int(provider.fieldDim) {
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
provider.fieldDim, len(resp.Embedding))
}
data = append(data, resp.Embedding)
}
return data, nil
}
type BedRockRequest struct {
InputText string `json:"inputText"`
Dimensions int64 `json:"dimensions,omitempty"`
Normalize bool `json:"normalize,omitempty"`
}
type BedRockResponse struct {
Embedding []float32 `json:"embedding"`
InputTextTokenCount int `json:"inputTextTokenCount"`
}

View File

@ -0,0 +1,177 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func TestBedrockTextEmbeddingProvider(t *testing.T) {
suite.Run(t, new(BedrockTextEmbeddingProviderSuite))
}
type BedrockTextEmbeddingProviderSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *BedrockTextEmbeddingProviderSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
s.providers = []string{bedrockProvider}
}
func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, dim int) (textEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
},
}
switch providerName {
case bedrockProvider:
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim})
default:
return nil, fmt.Errorf("Unknow provider")
}
}
func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() {
for _, provderName := range s.providers {
provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 4)
s.NoError(err)
{
data := []string{"sentence"}
ret, err2 := provder.CallEmbedding(data, InsertMode)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(4, len(ret[0]))
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
}
{
data := []string{"sentence 1", "sentence 2", "sentence 3"}
ret, _ := provder.CallEmbedding(data, SearchMode)
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret)
}
}
}
func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
for _, provderName := range s.providers {
provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 2)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *BedrockTextEmbeddingProviderSuite) TestCreateBedrockClient() {
_, err := createBedRockEmbeddingClient("", "", "")
s.Error(err)
_, err = createBedRockEmbeddingClient("mock_id", "", "")
s.Error(err)
_, err = createBedRockEmbeddingClient("", "mock_key", "")
s.Error(err)
_, err = createBedRockEmbeddingClient("mock_id", "mock_key", "")
s.Error(err)
_, err = createBedRockEmbeddingClient("mock_id", "mock_key", "mock_region")
s.NoError(err)
}
func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
fieldSchema := &schemapb.FieldSchema{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
}
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: awsAKIdParamKey, Value: "mock"},
{Key: awsSAKParamKey, Value: "mock"},
{Key: regionParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
{Key: normalizeParamKey, Value: "false"},
},
}
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.NoError(err)
s.True(provider.MaxBatch() > 0)
s.Equal(provider.FieldDim(), int64(4))
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.NoError(err)
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.Error(err)
// invalid model name
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.Error(err)
// invalid dim
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
s.Error(err)
}

View File

@ -0,0 +1,114 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"strconv"
)
type TextEmbeddingMode int
const (
InsertMode TextEmbeddingMode = iota
SearchMode
)
// common params
const (
modelNameParamKey string = "model_name"
dimParamKey string = "dim"
embeddingURLParamKey string = "url"
apiKeyParamKey string = "api_key"
)
// ali text embedding
const (
TextEmbeddingV1 string = "text-embedding-v1"
TextEmbeddingV2 string = "text-embedding-v2"
TextEmbeddingV3 string = "text-embedding-v3"
dashscopeAKEnvStr string = "MILVUSAI_DASHSCOPE_API_KEY"
)
// openai/azure text embedding
const (
TextEmbeddingAda002 string = "text-embedding-ada-002"
TextEmbedding3Small string = "text-embedding-3-small"
TextEmbedding3Large string = "text-embedding-3-large"
openaiAKEnvStr string = "MILVUSAI_OPENAI_API_KEY"
azureOpenaiAKEnvStr string = "MILVUSAI_AZURE_OPENAI_API_KEY"
azureOpenaiResourceName string = "MILVUSAI_AZURE_OPENAI_RESOURCE_NAME"
userParamKey string = "user"
)
// bedrock emebdding
const (
BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0"
awsAKIdParamKey string = "aws_access_key_id"
awsSAKParamKey string = "aws_secret_access_key"
regionParamKey string = "regin"
normalizeParamKey string = "normalize"
bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID"
bedrockSAKEnvStr string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY"
)
// vertexAI
const (
locationParamKey string = "location"
projectIDParamKey string = "projectid"
taskTypeParamKey string = "task"
textEmbedding005 string = "text-embedding-005"
textMultilingualEmbedding002 string = "text-multilingual-embedding-002"
vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS"
)
// voyageAI
const (
voyage3Large string = "voyage-3-large"
voyage3 string = "voyage-3"
voyage3Lite string = "voyage-3-lite"
voyageCode3 string = "voyage-code-3"
voyageFinance2 string = "voyage-finance-2"
voyageLaw2 string = "voyage-law-2"
voyageCode2 string = "voyage-code-2"
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
)
func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
dim, err := strconv.ParseInt(dimStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("dimension [%s] provided in Function params is not a valid int", dimStr)
}
if dim != 0 && dim != fieldDim {
return 0, fmt.Errorf("Function output field:[%s]'s dimension [%d] does not match the dimension [%d] provided in Function params.", fieldName, fieldDim, dim)
}
return dim, nil
}

View File

@ -35,6 +35,8 @@ func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Functio
switch schema.GetType() {
case schemapb.FunctionType_BM25:
return NewBM25FunctionRunner(coll, schema)
case schemapb.FunctionType_TextEmbedding:
return nil, nil
default:
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
}

View File

@ -0,0 +1,57 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type FunctionBase struct {
schema *schemapb.FunctionSchema
outputFields []*schemapb.FieldSchema
}
func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) {
var base FunctionBase
base.schema = fSchema
for _, fieldName := range fSchema.GetOutputFieldNames() {
for _, field := range coll.GetFields() {
if field.GetName() == fieldName {
base.outputFields = append(base.outputFields, field)
break
}
}
}
if len(base.outputFields) != len(fSchema.GetOutputFieldNames()) {
return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema",
coll.Name, fSchema.Name)
}
return &base, nil
}
func (base *FunctionBase) GetSchema() *schemapb.FunctionSchema {
return base.schema
}
func (base *FunctionBase) GetOutputFields() []*schemapb.FieldSchema {
return base.outputFields
}

View File

@ -0,0 +1,255 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"sync"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type Runner interface {
GetSchema() *schemapb.FunctionSchema
GetOutputFields() []*schemapb.FieldSchema
MaxBatch() int
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
}
type FunctionExecutor struct {
runners map[int64]Runner
}
func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (Runner, error) {
switch schema.GetType() {
case schemapb.FunctionType_BM25: // ignore bm25 function
return nil, nil
case schemapb.FunctionType_TextEmbedding:
f, err := NewTextEmbeddingFunction(coll, schema)
if err != nil {
return nil, err
}
return f, nil
default:
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
}
}
// Since bm25 and embedding are implemented in different ways, the bm25 function is not verified here.
func ValidateFunctions(schema *schemapb.CollectionSchema) error {
for _, fSchema := range schema.Functions {
if _, err := createFunction(schema, fSchema); err != nil {
return err
}
}
return nil
}
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
executor := &FunctionExecutor{
runners: make(map[int64]Runner),
}
for _, fSchema := range schema.Functions {
runner, err := createFunction(schema, fSchema)
if err != nil {
return nil, err
}
if runner != nil {
executor.runners[fSchema.GetOutputFieldIds()[0]] = runner
}
}
return executor, nil
}
func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames()))
for _, name := range runner.GetSchema().GetInputFieldNames() {
for _, field := range msg.FieldsData {
if field.GetFieldName() == name {
inputs = append(inputs, field)
}
}
}
if len(inputs) != len(runner.GetSchema().InputFieldIds) {
return nil, fmt.Errorf("Input field not found")
}
outputs, err := runner.ProcessInsert(inputs)
if err != nil {
return nil, err
}
return outputs, nil
}
func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error {
numRows := msg.NumRows
for _, runner := range executor.runners {
if numRows > uint64(runner.MaxBatch()) {
return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch())
}
}
outputs := make(chan []*schemapb.FieldData, len(executor.runners))
errChan := make(chan error, len(executor.runners))
var wg sync.WaitGroup
for _, runner := range executor.runners {
wg.Add(1)
go func(runner Runner) {
defer wg.Done()
data, err := executor.processSingleFunction(runner, msg)
if err != nil {
errChan <- err
return
}
outputs <- data
}(runner)
}
wg.Wait()
close(errChan)
close(outputs)
// Collect all errors
var errs []error
for err := range errChan {
errs = append(errs, err)
}
if len(errs) > 0 {
return fmt.Errorf("multiple errors occurred: %v", errs)
}
for output := range outputs {
msg.FieldsData = append(msg.FieldsData, output...)
}
return nil
}
func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) {
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroup, pb)
if len(pb.Placeholders) != 1 {
return nil, merr.WrapErrParameterInvalidMsg("No placeholders founded")
}
if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar {
return placeholderGroup, nil
}
res, err := runner.ProcessSearch(pb)
if err != nil {
return nil, err
}
return proto.Marshal(res)
}
func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error {
runner, exist := executor.runners[req.FieldId]
if !exist {
return fmt.Errorf("Can not found function in field %d", req.FieldId)
}
if req.Nq > int64(runner.MaxBatch()) {
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch())
}
if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil {
return err
} else {
req.PlaceholderGroup = newHolder
}
return nil
}
func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error {
outputs := make(chan map[int64][]byte, len(req.GetSubReqs()))
errChan := make(chan error, len(req.GetSubReqs()))
var wg sync.WaitGroup
for idx, sub := range req.GetSubReqs() {
if runner, exist := executor.runners[sub.FieldId]; exist {
if sub.Nq > int64(runner.MaxBatch()) {
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch())
}
wg.Add(1)
go func(runner Runner, idx int64, placeholderGroup []byte) {
defer wg.Done()
if newHolder, err := executor.processSingleSearch(runner, placeholderGroup); err != nil {
errChan <- err
} else {
outputs <- map[int64][]byte{idx: newHolder}
}
}(runner, int64(idx), sub.GetPlaceholderGroup())
}
}
wg.Wait()
close(errChan)
close(outputs)
for err := range errChan {
return err
}
for output := range outputs {
for idx, holder := range output {
req.SubReqs[idx].PlaceholderGroup = holder
}
}
return nil
}
func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error {
if !req.IsAdvanced {
return executor.prcessSearch(req)
}
return executor.prcessAdvanceSearch(req)
}
func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {
inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds))
for idx, id := range runner.GetSchema().InputFieldIds {
field, exist := data.Data[id]
if !exist {
return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx])
}
inputs = append(inputs, field)
}
outputs, err := runner.ProcessBulkInsert(inputs)
if err != nil {
return nil, err
}
return outputs, nil
}
func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error {
// Since concurrency has already been used in the outer layer, only a serial logic access model is used here.
for _, runner := range executor.runners {
output, err := executor.processSingleBulkInsert(runner, data)
if err != nil {
return nil
}
for k, v := range output {
data.Data[k] = v
}
}
return nil
}

View File

@ -0,0 +1,340 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/openai"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
func TestFunctionExecutor(t *testing.T) {
suite.Run(t, new(FunctionExecutorSuite))
}
type FunctionExecutorSuite struct {
suite.Suite
}
func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema {
return &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
IsFunctionOutput: true,
},
{
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "8"},
},
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector"},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "4"},
},
},
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{103},
OutputFieldNames: []string{"vector2"},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "8"},
},
},
},
}
}
func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg {
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
FieldName: "text",
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: texts,
},
},
},
},
}
data = append(data, &f)
msg := msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
FieldsData: data,
},
}
return &msg
}
func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]float32 {
embeddings := make([][]float32, 0)
for i := 0; i < len(texts); i++ {
f := float32(i)
emb := make([]float32, 0)
for j := 0; j < dim; j++ {
emb = append(emb, f+float32(j)*0.1)
}
embeddings = append(embeddings, emb)
}
return embeddings
}
func (s *FunctionExecutorSuite) TestExecutor() {
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
exec.ProcessInsert(msg)
s.Equal(len(msg.FieldsData), 3)
}
func (s *FunctionExecutorSuite) TestErrorEmbedding() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req openai.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
var res openai.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, openai.EmbeddingData{
Object: "embedding",
Embedding: []float32{},
Index: i,
})
}
res.Usage = openai.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema)
fmt.Println(err)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
err = exec.ProcessInsert(msg)
s.Error(err)
}
func (s *FunctionExecutorSuite) TestErrorSchema() {
schema := s.creataSchema("http://localhost")
schema.Functions[0].Type = schemapb.FunctionType_Unknown
_, err := NewFunctionExecutor(schema)
s.Error(err)
}
func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
{
f := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strings.Split("helle,world", ","),
},
},
},
},
}
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
s.NoError(err)
req := &internalpb.SearchRequest{
Nq: 2,
PlaceholderGroup: placeholderGroupBytes,
IsAdvanced: false,
FieldId: 102,
}
err = exec.ProcessSearch(req)
s.NoError(err)
// No function found
req = &internalpb.SearchRequest{
Nq: 2,
PlaceholderGroup: placeholderGroupBytes,
IsAdvanced: false,
FieldId: 111,
}
err = exec.ProcessSearch(req)
s.Error(err)
// Large search nq
req = &internalpb.SearchRequest{
Nq: 1000,
PlaceholderGroup: placeholderGroupBytes,
IsAdvanced: false,
FieldId: 102,
}
err = exec.ProcessSearch(req)
s.Error(err)
}
// AdvanceSearch
{
f := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strings.Split("helle,world", ","),
},
},
},
},
}
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
s.NoError(err)
subReq := &internalpb.SubSearchRequest{
PlaceholderGroup: placeholderGroupBytes,
Nq: 2,
FieldId: 102,
}
req := &internalpb.SearchRequest{
IsAdvanced: true,
SubReqs: []*internalpb.SubSearchRequest{subReq},
}
err = exec.ProcessSearch(req)
s.NoError(err)
// Large nq
subReq.Nq = 1000
err = exec.ProcessSearch(req)
s.Error(err)
}
}
func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
ts := CreateErrorEmbeddingServer()
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
f := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strings.Split("helle,world", ","),
},
},
},
},
}
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
s.NoError(err)
{
req := &internalpb.SearchRequest{
Nq: 2,
PlaceholderGroup: placeholderGroupBytes,
IsAdvanced: false,
FieldId: 102,
}
err = exec.ProcessSearch(req)
s.Error(err)
}
// AdvanceSearch
{
subReq := &internalpb.SubSearchRequest{
PlaceholderGroup: placeholderGroupBytes,
Nq: 2,
FieldId: 102,
}
req := &internalpb.SearchRequest{
IsAdvanced: true,
SubReqs: []*internalpb.SubSearchRequest{subReq},
}
err = exec.ProcessSearch(req)
s.Error(err)
}
}

View File

@ -0,0 +1,47 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
// Determine whether the column corresponding to outputIDs contains functions, except bm25 function,
// if outputIDs is empty, check all cols
func HasNonBM25Functions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool {
for _, fSchema := range functions {
switch fSchema.GetType() {
case schemapb.FunctionType_BM25:
case schemapb.FunctionType_Unknown:
default:
if len(outputIDs) == 0 {
return true
} else {
for _, id := range outputIDs {
for _, fOutputID := range fSchema.GetOutputFieldIds() {
if fOutputID == id {
return true
}
}
}
}
}
}
return false
}

View File

@ -0,0 +1,184 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/milvus-io/milvus/internal/util/function/models/ali"
"github.com/milvus-io/milvus/internal/util/function/models/openai"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
)
func mockEmbedding(texts []string, dim int) [][]float32 {
embeddings := make([][]float32, 0)
for i := 0; i < len(texts); i++ {
f := float32(i)
emb := make([]float32, 0)
for j := 0; j < dim; j++ {
emb = append(emb, f+float32(j)*0.1)
}
embeddings = append(embeddings, emb)
}
return embeddings
}
func CreateErrorEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
return ts
}
func CreateOpenAIEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req openai.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
embs := mockEmbedding(req.Input, req.Dimensions)
var res openai.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, openai.EmbeddingData{
Object: "embedding",
Embedding: embs[i],
Index: i,
})
}
res.Usage = openai.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
return ts
}
func CreateAliEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req ali.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
embs := mockEmbedding(req.Input.Texts, req.Parameters.Dimension)
var res ali.EmbeddingResponse
for i := 0; i < len(req.Input.Texts); i++ {
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: embs[i],
TextIndex: i,
})
}
res.Usage = ali.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
return ts
}
func CreateVoyageAIEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req voyageai.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
embs := mockEmbedding(req.Input, int(req.OutputDimension))
var res voyageai.EmbeddingResponse
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, voyageai.EmbeddingData{
Object: "list",
Embedding: embs[i],
Index: i,
})
}
res.Usage = voyageai.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
return ts
}
func CreateVertexAIEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req vertexai.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
var texts []string
for _, item := range req.Instances {
texts = append(texts, item.Content)
}
embs := mockEmbedding(texts, int(req.Parameters.OutputDimensionality))
var res vertexai.EmbeddingResponse
for i := 0; i < len(req.Instances); i++ {
res.Predictions = append(res.Predictions, vertexai.Prediction{
Embeddings: vertexai.Embeddings{
Statistics: vertexai.Statistics{
Truncated: false,
TokenCount: 10,
},
Values: embs[i],
},
})
}
res.Metadata = vertexai.Metadata{
BillableCharacterCount: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
return ts
}
type MockBedrockClient struct {
dim int
}
func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) {
var req BedRockRequest
json.Unmarshal(params.Body, &req)
embs := mockEmbedding([]string{req.InputText}, c.dim)
var resp BedRockResponse
resp.Embedding = embs[0]
resp.InputTextTokenCount = 2
body, _ := json.Marshal(resp)
return &bedrockruntime.InvokeModelOutput{Body: body}, nil
}

View File

@ -0,0 +1,156 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ali
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"time"
"github.com/milvus-io/milvus/internal/util/function/models/utils"
)
type Input struct {
Texts []string `json:"texts"`
}
type Parameters struct {
TextType string `json:"text_type,omitempty"`
Dimension int `json:"dimension,omitempty"`
OutputType string `json:"output_type,omitempty"`
}
type EmbeddingRequest struct {
// ID of the model to use.
Model string `json:"model"`
// Input text to embed, encoded as a string.
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type Usage struct {
// The total number of tokens used by the request.
TotalTokens int `json:"total_tokens"`
}
type SparseEmbedding struct {
Index int `json:"index"`
Value float32 `json:"value"`
Token string `json:"token"`
}
type Embeddings struct {
TextIndex int `json:"text_index"`
Embedding []float32 `json:"embedding,omitempty"`
SparseEmbedding []SparseEmbedding `json:"sparse_embedding,omitempty"`
}
type Output struct {
Embeddings []Embeddings `json:"embeddings"`
}
type EmbeddingResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
RequestID string `json:"request_id"`
}
type ByIndex struct {
resp *EmbeddingResponse
}
func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) }
func (eb *ByIndex) Swap(i, j int) {
eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i]
}
func (eb *ByIndex) Less(i, j int) bool {
return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex
}
type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}
type AliDashScopeEmbedding struct {
apiKey string
url string
}
func NewAliDashScopeEmbeddingClient(apiKey string, url string) *AliDashScopeEmbedding {
return &AliDashScopeEmbedding{
apiKey: apiKey,
url: url,
}
}
func (c *AliDashScopeEmbedding) Check() error {
if c.apiKey == "" {
return fmt.Errorf("api key is empty")
}
if c.url == "" {
return fmt.Errorf("url is empty")
}
return nil
}
func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = modelName
r.Input = Input{texts}
r.Parameters.Dimension = dim
r.Parameters.TextType = textType
r.Parameters.OutputType = outputType
data, err := json.Marshal(r)
if err != nil {
return nil, err
}
if timeoutSec <= 0 {
timeoutSec = utils.DefaultTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
}
sort.Sort(&ByIndex{&res})
return &res, err
}

View File

@ -0,0 +1,116 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ali
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEmbeddingClientCheck(t *testing.T) {
{
c := NewAliDashScopeEmbeddingClient("", "mock_uri")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewAliDashScopeEmbeddingClient("mock_key", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewAliDashScopeEmbeddingClient("mock_key", "mock_uri")
err := c.Check()
assert.True(t, err == nil)
}
}
func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
repStr := `{
"output": {
"embeddings": [
{
"text_index": 1,
"embedding": [0.1]
},
{
"text_index": 0,
"embedding": [0.0]
},
{
"text_index": 2,
"embedding": [0.2]
}
]
},
"usage": {
"total_tokens": 100
},
"request_id": "0000000000000"
}`
err := json.Unmarshal([]byte(repStr), &res)
assert.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
url := ts.URL
{
c := NewAliDashScopeEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Output.Embeddings[0].TextIndex, 0)
assert.Equal(t, ret.Output.Embeddings[1].TextIndex, 1)
assert.Equal(t, ret.Output.Embeddings[2].TextIndex, 2)
assert.Equal(t, ret.Output.Embeddings[0].Embedding, []float32{0.0})
assert.Equal(t, ret.Output.Embeddings[1].Embedding, []float32{0.1})
assert.Equal(t, ret.Output.Embeddings[2].Embedding, []float32{0.2})
}
}
func TestEmbeddingFailed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer ts.Close()
url := ts.URL
{
c := NewAliDashScopeEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0)
assert.True(t, err != nil)
}
}

View File

@ -0,0 +1,225 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"time"
"github.com/milvus-io/milvus/internal/util/function/models/utils"
)
type EmbeddingRequest struct {
// ID of the model to use.
Model string `json:"model"`
// Input text to embed, encoded as a string.
Input []string `json:"input"`
// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
User string `json:"user,omitempty"`
// The format to return the embeddings in. Can be either float or base64.
EncodingFormat string `json:"encoding_format,omitempty"`
// The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
}
type Usage struct {
// The number of tokens used by the prompt.
PromptTokens int `json:"prompt_tokens"`
// The total number of tokens used by the request.
TotalTokens int `json:"total_tokens"`
}
type EmbeddingData struct {
// The object type, which is always "embedding".
Object string `json:"object"`
// The embedding vector, which is a list of floats.
Embedding []float32 `json:"embedding"`
// The index of the embedding in the list of embeddings.
Index int `json:"index"`
}
type EmbeddingResponse struct {
// The object type, which is always "list".
Object string `json:"object"`
// The list of embeddings generated by the model.
Data []EmbeddingData `json:"data"`
// The name of the model used to generate the embedding.
Model string `json:"model"`
// The usage information for the request.
Usage Usage `json:"usage"`
}
type ByIndex struct {
resp *EmbeddingResponse
}
func (eb *ByIndex) Len() int { return len(eb.resp.Data) }
func (eb *ByIndex) Swap(i, j int) {
eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i]
}
func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index }
type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
Param string `json:"param,omitempty"`
Type string `json:"type"`
}
type EmbedddingError struct {
Error ErrorInfo `json:"error"`
}
type OpenAIEmbeddingInterface interface {
Check() error
Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error)
}
type openAIBase struct {
apiKey string
url string
}
func (c *openAIBase) Check() error {
if c.apiKey == "" {
return fmt.Errorf("api key is empty")
}
if c.url == "" {
return fmt.Errorf("url is empty")
}
return nil
}
func (c *openAIBase) genReq(modelName string, texts []string, dim int, user string) *EmbeddingRequest {
var r EmbeddingRequest
r.Model = modelName
r.Input = texts
r.EncodingFormat = "float"
if user != "" {
r.User = user
}
if dim != 0 {
r.Dimensions = dim
}
return &r
}
func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
r := c.genReq(modelName, texts, dim, user)
data, err := json.Marshal(r)
if err != nil {
return nil, err
}
if timeoutSec <= 0 {
timeoutSec = utils.DefaultTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
for key, value := range headers {
req.Header.Set(key, value)
}
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
}
sort.Sort(&ByIndex{&res})
return &res, err
}
type OpenAIEmbeddingClient struct {
openAIBase
}
func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient {
return &OpenAIEmbeddingClient{
openAIBase{
apiKey: apiKey,
url: url,
},
}
}
func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec)
}
type AzureOpenAIEmbeddingClient struct {
openAIBase
apiVersion string
}
func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbeddingClient {
return &AzureOpenAIEmbeddingClient{
openAIBase: openAIBase{
apiKey: apiKey,
url: url,
},
apiVersion: "2024-06-01",
}
}
func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
base, err := url.Parse(c.url)
if err != nil {
return nil, err
}
path := fmt.Sprintf("/openai/deployments/%s/embeddings", modelName)
base.Path = path
params := url.Values{}
params.Add("api-version", c.apiVersion)
base.RawQuery = params.Encode()
url := base.String()
headers := map[string]string{
"Content-Type": "application/json",
"api-key": c.apiKey,
}
return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec)
}

View File

@ -0,0 +1,252 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package openai
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestEmbeddingClientCheck(t *testing.T) {
{
c := NewOpenAIEmbeddingClient("", "mock_uri")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewOpenAIEmbeddingClient("mock_key", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewOpenAIEmbeddingClient("mock_key", "mock_uri")
err := c.Check()
assert.True(t, err == nil)
}
}
func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
res.Data = []EmbeddingData{
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
Index: 1,
},
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
Index: 0,
},
}
res.Usage = Usage{
PromptTokens: 1,
TotalTokens: 100,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
if r.Header["Authorization"][0] != "" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusBadRequest)
}
} else {
if r.Header["Api-Key"][0] != "" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusBadRequest)
}
}
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
url := ts.URL
{
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
assert.Equal(t, ret.Data[1].Index, 1)
}
{
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
assert.Equal(t, ret.Data[1].Index, 1)
}
}
func TestEmbeddingRetry(t *testing.T) {
var res EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
res.Data = []EmbeddingData{
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.2, 4.5},
Index: 2,
},
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
Index: 0,
},
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.2, 4.3},
Index: 1,
},
}
res.Usage = Usage{
PromptTokens: 1,
TotalTokens: 100,
}
var count int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.LoadInt32(&count) < 2 {
atomic.AddInt32(&count, 1)
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}
}))
defer ts.Close()
url := ts.URL
{
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Usage, res.Usage)
assert.Equal(t, ret.Object, res.Object)
assert.Equal(t, ret.Model, res.Model)
assert.Equal(t, ret.Data[0], res.Data[1])
assert.Equal(t, ret.Data[1], res.Data[2])
assert.Equal(t, ret.Data[2], res.Data[0])
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
}
{
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Usage, res.Usage)
assert.Equal(t, ret.Object, res.Object)
assert.Equal(t, ret.Model, res.Model)
assert.Equal(t, ret.Data[0], res.Data[1])
assert.Equal(t, ret.Data[1], res.Data[2])
assert.Equal(t, ret.Data[2], res.Data[0])
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
}
}
func TestEmbeddingFailed(t *testing.T) {
var count int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&count, 1)
w.WriteHeader(http.StatusUnauthorized)
}))
defer ts.Close()
url := ts.URL
{
atomic.StoreInt32(&count, 0)
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
}
{
atomic.StoreInt32(&count, 0)
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
}
}
func TestTimeout(t *testing.T) {
var st int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(3 * time.Second)
atomic.AddInt32(&st, 1)
w.WriteHeader(http.StatusUnauthorized)
}))
defer ts.Close()
url := ts.URL
{
atomic.StoreInt32(&st, 0)
c := NewOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&st), int32(0))
time.Sleep(3 * time.Second)
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
}
{
atomic.StoreInt32(&st, 0)
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1)
assert.True(t, err != nil)
assert.Equal(t, atomic.LoadInt32(&st), int32(0))
time.Sleep(3 * time.Second)
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
}
}

View File

@ -0,0 +1,55 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"fmt"
"io"
"net/http"
)
const DefaultTimeout int64 = 30
func send(req *http.Request) ([]byte, error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(string(body))
}
return body, nil
}
func RetrySend(req *http.Request, maxRetries int) ([]byte, error) {
var err error
var res []byte
for i := 0; i < maxRetries; i++ {
res, err = send(req)
if err == nil {
return res, nil
}
}
return nil, err
}

View File

@ -0,0 +1,163 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vertexai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"golang.org/x/oauth2/google"
"github.com/milvus-io/milvus/internal/util/function/models/utils"
)
type Instance struct {
TaskType string `json:"task_type,omitempty"`
Content string `json:"content"`
}
type Parameters struct {
OutputDimensionality int64 `json:"outputDimensionality,omitempty"`
}
type EmbeddingRequest struct {
Instances []Instance `json:"instances"`
Parameters Parameters `json:"parameters,omitempty"`
}
type Statistics struct {
Truncated bool `json:"truncated"`
TokenCount int `json:"token_count"`
}
type Embeddings struct {
Statistics Statistics `json:"statistics"`
Values []float32 `json:"values"`
}
type Prediction struct {
Embeddings Embeddings `json:"embeddings"`
}
type Metadata struct {
BillableCharacterCount int `json:"billableCharacterCount"`
}
type EmbeddingResponse struct {
Predictions []Prediction `json:"predictions"`
Metadata Metadata `json:"metadata"`
}
type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}
type VertexAIEmbedding struct {
url string
jsonKey []byte
scopes string
token string
}
func NewVertexAIEmbedding(url string, jsonKey []byte, scopes string, token string) *VertexAIEmbedding {
return &VertexAIEmbedding{
url: url,
jsonKey: jsonKey,
scopes: scopes,
token: token,
}
}
func (c *VertexAIEmbedding) Check() error {
if c.url == "" {
return fmt.Errorf("VertexAI embedding url is empty")
}
if len(c.jsonKey) == 0 {
return fmt.Errorf("jsonKey is empty")
}
if c.scopes == "" {
return fmt.Errorf("Scopes param is empty")
}
return nil
}
func (c *VertexAIEmbedding) getAccessToken() (string, error) {
ctx := context.Background()
creds, err := google.CredentialsFromJSON(ctx, c.jsonKey, c.scopes)
if err != nil {
return "", fmt.Errorf("Failed to find credentials: %v", err)
}
token, err := creds.TokenSource.Token()
if err != nil {
return "", fmt.Errorf("Failed to get token: %v", err)
}
return token.AccessToken, nil
}
func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
for _, text := range texts {
r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text})
}
if dim != 0 {
r.Parameters.OutputDimensionality = dim
}
data, err := json.Marshal(r)
if err != nil {
return nil, err
}
if timeoutSec <= 0 {
timeoutSec = utils.DefaultTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
var token string
if c.token != "" {
token = c.token
} else {
token, err = c.getAccessToken()
if err != nil {
return nil, err
}
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
}
return &res, err
}

View File

@ -0,0 +1,90 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vertexai
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEmbeddingClientCheck(t *testing.T) {
mockJSONKey := []byte{1, 2, 3}
{
c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewVertexAIEmbedding("", mockJSONKey, "", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "")
err := c.Check()
assert.True(t, err == nil)
}
}
func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
repStr := `{"predictions": [{"embeddings": {"statistics": {"truncated": false, "token_count": 4}, "values": [-0.028420744463801384, 0.037183016538619995]}}, {"embeddings": {"statistics": {"truncated": false, "token_count": 8}, "values": [-0.04367655888199806, 0.03777721896767616, 0.0158217903226614]}}], "metadata": {"billableCharacterCount": 27}}`
err := json.Unmarshal([]byte(repStr), &res)
assert.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
url := ts.URL
{
c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token")
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-005", []string{"sentence"}, 0, "query", 0)
assert.True(t, err == nil)
}
}
func TestEmbeddingFailed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer ts.Close()
url := ts.URL
{
c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token")
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", 0)
assert.True(t, err != nil)
}
}

View File

@ -0,0 +1,152 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package voyageai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"time"
"github.com/milvus-io/milvus/internal/util/function/models/utils"
)
type EmbeddingRequest struct {
// ID of the model to use.
Model string `json:"model"`
// Input text to embed, encoded as a string.
Input []string `json:"input"`
InputType string `json:"input_type,omitempty"`
Truncation bool `json:"truncation,omitempty"`
OutputDimension int64 `json:"output_dimension,omitempty"`
OutputDtype string `json:"output_dtype,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"`
}
type Usage struct {
// The total number of tokens used by the request.
TotalTokens int `json:"total_tokens"`
}
type EmbeddingData struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type EmbeddingResponse struct {
Object string `json:"object"`
Data []EmbeddingData `json:"data"`
Model string `json:"model"`
Usage Usage `json:"usage"`
}
type ByIndex struct {
resp *EmbeddingResponse
}
func (eb *ByIndex) Len() int { return len(eb.resp.Data) }
func (eb *ByIndex) Swap(i, j int) {
eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i]
}
func (eb *ByIndex) Less(i, j int) bool {
return eb.resp.Data[i].Index < eb.resp.Data[j].Index
}
type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}
type VoyageAIEmbedding struct {
apiKey string
url string
}
func NewVoyageAIEmbeddingClient(apiKey string, url string) *VoyageAIEmbedding {
return &VoyageAIEmbedding{
apiKey: apiKey,
url: url,
}
}
func (c *VoyageAIEmbedding) Check() error {
if c.apiKey == "" {
return fmt.Errorf("api key is empty")
}
if c.url == "" {
return fmt.Errorf("url is empty")
}
return nil
}
func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = modelName
r.Input = texts
r.InputType = textType
r.OutputDtype = outputType
if dim != 0 {
r.OutputDimension = int64(dim)
}
data, err := json.Marshal(r)
if err != nil {
return nil, err
}
if timeoutSec <= 0 {
timeoutSec = utils.DefaultTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
}
sort.Sort(&ByIndex{&res})
return &res, err
}

View File

@ -0,0 +1,127 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package voyageai
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEmbeddingClientCheck(t *testing.T) {
{
c := NewVoyageAIEmbeddingClient("", "mock_uri")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewVoyageAIEmbeddingClient("mock_key", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}
{
c := NewVoyageAIEmbeddingClient("mock_key", "mock_uri")
err := c.Check()
assert.True(t, err == nil)
}
}
func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
repStr := `{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [
0.0,
0.1
],
"index": 0
},
{
"object": "embedding",
"embedding": [
2.0,
2.1
],
"index": 2
},
{
"object": "embedding",
"embedding": [
1.0,
1.1
],
"index": 1
}
],
"model": "voyage-large-2",
"usage": {
"total_tokens": 10
}
}`
err := json.Unmarshal([]byte(repStr), &res)
assert.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
url := ts.URL
{
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
assert.Equal(t, ret.Data[1].Index, 1)
assert.Equal(t, ret.Data[2].Index, 2)
assert.Equal(t, ret.Data[0].Embedding, []float32{0.0, 0.1})
assert.Equal(t, ret.Data[1].Embedding, []float32{1.0, 1.1})
assert.Equal(t, ret.Data[2].Embedding, []float32{2.0, 2.1})
}
}
func TestEmbeddingFailed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer ts.Close()
url := ts.URL
{
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
assert.True(t, err != nil)
}
}

View File

@ -0,0 +1,176 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"os"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/openai"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type OpenAIEmbeddingProvider struct {
fieldDim int64
client openai.OpenAIEmbeddingInterface
modelName string
embedDimParam int64
user string
maxBatch int
timeoutSec int64
}
func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) {
if apiKey == "" {
apiKey = os.Getenv(openaiAKEnvStr)
}
if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr)
}
if url == "" {
url = "https://api.openai.com/v1/embeddings"
}
c := openai.NewOpenAIEmbeddingClient(apiKey, url)
return c, nil
}
func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) {
if apiKey == "" {
apiKey = os.Getenv(azureOpenaiAKEnvStr)
}
if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr)
}
if url == "" {
if resourceName := os.Getenv(azureOpenaiResourceName); resourceName != "" {
url = fmt.Sprintf("https://%s.openai.azure.com", resourceName)
}
}
if url == "" {
return nil, fmt.Errorf("Must configure the %s environment variable in the Milvus service", azureOpenaiResourceName)
}
c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url)
return c, nil
}
func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, isAzure bool) (*OpenAIEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil {
return nil, err
}
var apiKey, url, modelName, user string
var dim int64
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case dimParamKey:
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
if err != nil {
return nil, err
}
case userParamKey:
user = param.Value
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
default:
}
}
var c openai.OpenAIEmbeddingInterface
if !isAzure {
if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
}
c, err = createOpenAIEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
}
} else {
c, err = createAzureOpenAIEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
}
}
provider := OpenAIEmbeddingProvider{
client: c,
fieldDim: fieldDim,
modelName: modelName,
user: user,
embedDimParam: dim,
maxBatch: 128,
timeoutSec: 30,
}
return &provider, nil
}
func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) {
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, false)
}
func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) {
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, true)
}
func (provider *OpenAIEmbeddingProvider) MaxBatch() int {
return 5 * provider.maxBatch
}
func (provider *OpenAIEmbeddingProvider) FieldDim() int64 {
return provider.fieldDim
}
func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, _ TextEmbeddingMode) ([][]float32, error) {
numRows := len(texts)
data := make([][]float32, 0, numRows)
for i := 0; i < numRows; i += provider.maxBatch {
end := i + provider.maxBatch
if end > numRows {
end = numRows
}
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), provider.user, provider.timeoutSec)
if err != nil {
return nil, err
}
if end-i != len(resp.Data) {
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data))
}
for _, item := range resp.Data {
if len(item.Embedding) != int(provider.fieldDim) {
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
provider.fieldDim, len(item.Embedding))
}
data = append(data, item.Embedding)
}
}
return data, nil
}

View File

@ -0,0 +1,206 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/openai"
)
func TestOpenAITextEmbeddingProvider(t *testing.T) {
suite.Run(t, new(OpenAITextEmbeddingProviderSuite))
}
type OpenAITextEmbeddingProviderSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *OpenAITextEmbeddingProviderSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
s.providers = []string{openAIProvider, azureOpenAIProvider}
}
func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: dimParamKey, Value: "4"},
{Key: embeddingURLParamKey, Value: url},
},
}
switch providerName {
case openAIProvider:
return NewOpenAIEmbeddingProvider(schema, functionSchema)
case azureOpenAIProvider:
return NewAzureOpenAIEmbeddingProvider(schema, functionSchema)
default:
return nil, fmt.Errorf("Unknow provider")
}
}
func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() {
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
{
data := []string{"sentence"}
ret, err2 := provder.CallEmbedding(data, InsertMode)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(4, len(ret[0]))
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
}
{
data := []string{"sentence 1", "sentence 2", "sentence 3"}
ret, _ := provder.CallEmbedding(data, SearchMode)
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
}
}
}
func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res openai.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
res.Data = append(res.Data, openai.EmbeddingData{
Object: "embedding",
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
Index: 0,
})
res.Data = append(res.Data, openai.EmbeddingData{
Object: "embedding",
Embedding: []float32{1.0, 1.0},
Index: 1,
})
res.Usage = openai.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res openai.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
res.Data = append(res.Data, openai.EmbeddingData{
Object: "embedding",
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
Index: 0,
})
res.Usage = openai.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence2"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *OpenAITextEmbeddingProviderSuite) TestCreateOpenAIEmbeddingClient() {
_, err := createOpenAIEmbeddingClient("", "")
s.Error(err)
os.Setenv(openaiAKEnvStr, "mockKey")
defer os.Unsetenv(openaiAKEnvStr)
_, err = createOpenAIEmbeddingClient("", "")
s.NoError(err)
}
func (s *OpenAITextEmbeddingProviderSuite) TestCreateAzureOpenAIEmbeddingClient() {
_, err := createAzureOpenAIEmbeddingClient("", "")
s.Error(err)
os.Setenv(azureOpenaiAKEnvStr, "mockKey")
defer os.Unsetenv(azureOpenaiAKEnvStr)
_, err = createAzureOpenAIEmbeddingClient("", "")
s.Error(err)
os.Setenv(azureOpenaiResourceName, "mockResource")
defer os.Unsetenv(azureOpenaiResourceName)
_, err = createAzureOpenAIEmbeddingClient("", "")
s.NoError(err)
}

View File

@ -0,0 +1,239 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
const (
Provider string = "provider"
)
const (
openAIProvider string = "openai"
azureOpenAIProvider string = "azure_openai"
aliDashScopeProvider string = "dashscope"
bedrockProvider string = "bedrock"
vertexAIProvider string = "vertexai"
voyageAIProvider string = "voyageai"
)
// Text embedding for retrieval task
type textEmbeddingProvider interface {
MaxBatch() int
CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error)
FieldDim() int64
}
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case Provider:
return strings.ToLower(param.Value), nil
default:
}
}
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
}
type TextEmbeddingFunction struct {
FunctionBase
embProvider textEmbeddingProvider
}
func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmbeddingFunction, error) {
if len(functionSchema.GetOutputFieldNames()) != 1 {
return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames()))
}
base, err := NewFunctionBase(coll, functionSchema)
if err != nil {
return nil, err
}
if base.outputFields[0].DataType != schemapb.DataType_FloatVector {
return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s], got [%s]",
schemapb.DataType_name[int32(schemapb.DataType_FloatVector)],
schemapb.DataType_name[int32(base.outputFields[0].DataType)])
}
provider, err := getProvider(functionSchema)
if err != nil {
return nil, err
}
switch provider {
case openAIProvider:
embP, err := NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
case azureOpenAIProvider:
embP, err := NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
case bedrockProvider:
embP, err := NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
case aliDashScopeProvider:
embP, err := NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
case vertexAIProvider:
embP, err := NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
case voyageAIProvider:
embP, err := NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
default:
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider)
}
}
func (runner *TextEmbeddingFunction) MaxBatch() int {
return runner.embProvider.MaxBatch()
}
func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs))
}
if inputs[0].Type != schemapb.DataType_VarChar {
return nil, fmt.Errorf("Text embedding only supports varchar field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)])
}
texts := inputs[0].GetScalars().GetStringData().GetData()
if texts == nil {
return nil, fmt.Errorf("Input texts is empty")
}
numRows := len(texts)
if numRows > runner.MaxBatch() {
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
}
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
if err != nil {
return nil, err
}
data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim()))
for _, emb := range embds {
data = append(data, emb...)
}
var outputField schemapb.FieldData
outputField.FieldId = runner.outputFields[0].FieldID
outputField.FieldName = runner.outputFields[0].Name
outputField.Type = runner.outputFields[0].DataType
outputField.IsDynamic = runner.outputFields[0].IsDynamic
outputField.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: data,
},
},
Dim: runner.embProvider.FieldDim(),
},
}
return []*schemapb.FieldData{&outputField}, nil
}
func (runner *TextEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally
numRows := len(texts)
if numRows > runner.MaxBatch() {
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
}
embds, err := runner.embProvider.CallEmbedding(texts, SearchMode)
if err != nil {
return nil, err
}
return funcutil.Float32VectorsToPlaceholderGroup(embds), nil
}
func (runner *TextEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("TextEmbedding function only receives one input, bug got [%d]", len(inputs))
}
if inputs[0].GetDataType() != schemapb.DataType_VarChar {
return nil, fmt.Errorf(" only supports varchar field, the input is not varchar")
}
texts, ok := inputs[0].GetDataRows().([]string)
if !ok {
return nil, fmt.Errorf("Input texts is empty")
}
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
if err != nil {
return nil, err
}
data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim()))
for _, emb := range embds {
data = append(data, emb...)
}
field := &storage.FloatVectorFieldData{
Data: data,
Dim: int(runner.embProvider.FieldDim()),
}
return map[storage.FieldID]storage.FieldData{
runner.outputFields[0].FieldID: field,
}, nil
}

View File

@ -0,0 +1,633 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"strings"
"testing"
"github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
func TestTextEmbeddingFunction(t *testing.T) {
suite.Run(t, new(TextEmbeddingFunctionSuite))
}
type TextEmbeddingFunctionSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
}
func (s *TextEmbeddingFunctionSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
}
func createData(texts []string) []*schemapb.FieldData {
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: texts,
},
},
},
},
}
data = append(data, &f)
return data
}
func (s *TextEmbeddingFunctionSuite) TestInvalidProvider() {
fSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
}
providerName, err := getProvider(fSchema)
s.Equal(providerName, openAIProvider)
s.NoError(err)
fSchema.Params = []*commonpb.KeyValuePair{}
providerName, err = getProvider(fSchema)
s.Equal(providerName, "")
s.Error(err)
}
func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
{
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
},
})
s.NoError(err)
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.ProcessInsert(data)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
}
}
{
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: azureOpenAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
},
})
s.NoError(err)
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.ProcessInsert(data)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
}
}
}
func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
ts := CreateAliEmbeddingServer()
defer ts.Close()
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: aliDashScopeProvider},
{Key: modelNameParamKey, Value: TextEmbeddingV3},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
},
})
s.NoError(err)
{
data := createData([]string{"sentence"})
ret, err2 := runner.ProcessInsert(data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.ProcessInsert(data)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
}
// multi-input
{
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{},
},
},
},
},
}
data = append(data, &f)
data = append(data, &f)
_, err := runner.ProcessInsert(data)
s.Error(err)
}
// wrong input data type
{
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_Int32,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{},
},
}
data = append(data, &f)
_, err := runner.ProcessInsert(data)
s.Error(err)
}
// empty input
{
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{},
},
}
data = append(data, &f)
_, err := runner.ProcessInsert(data)
s.Error(err)
}
// large input data
{
data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strings.Split(strings.Repeat("Element,", 1000), ","),
},
},
},
},
}
data = append(data, &f)
_, err := runner.ProcessInsert(data)
s.Error(err)
}
}
func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
// outputfield datatype mismatch
{
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
_, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
})
s.Error(err)
}
// outputfield number mismatc
{
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
{
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
_, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector", "vector2"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102, 103},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
})
s.Error(err)
}
// outputfield miss
{
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector2"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{103},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
})
s.Error(err)
}
// error model name
{
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-004"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
},
})
s.Error(err)
}
// no openai api key
{
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-003"},
},
})
s.Error(err)
}
}
func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
{
fSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: bedrockProvider},
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: awsAKIdParamKey, Value: "mock"},
{Key: awsSAKParamKey, Value: "mock"},
{Key: regionParamKey, Value: "mock"},
},
}
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
s.NoError(err)
fSchema.Params = []*commonpb.KeyValuePair{}
_, err = NewTextEmbeddingFunction(s.schema, fSchema)
s.Error(err)
}
{
fSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: aliDashScopeProvider},
{Key: modelNameParamKey, Value: TextEmbeddingV1},
{Key: apiKeyParamKey, Value: "mock"},
},
}
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
s.NoError(err)
fSchema.Params = []*commonpb.KeyValuePair{}
_, err = NewTextEmbeddingFunction(s.schema, fSchema)
s.Error(err)
}
{
fSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: voyageAIProvider},
{Key: modelNameParamKey, Value: voyage3},
{Key: apiKeyParamKey, Value: "mock"},
},
}
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
s.NoError(err)
fSchema.Params = []*commonpb.KeyValuePair{}
_, err = NewTextEmbeddingFunction(s.schema, fSchema)
s.Error(err)
}
// Invalid params
{
fSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{},
}
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
s.Error(err)
}
{
fSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: "unkownProvider"},
},
}
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
s.Error(err)
}
}
func (s *TextEmbeddingFunctionSuite) TestProcessSearch() {
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
},
})
s.NoError(err)
// Large inputs
{
f := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strings.Split(strings.Repeat("Element,", 1000), ","),
},
},
},
},
}
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
s.NoError(err)
placeholderGroup := commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
_, err = runner.ProcessSearch(&placeholderGroup)
s.Error(err)
}
// Normal inputs
{
f := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strings.Split(strings.Repeat("Element,", 100), ","),
},
},
},
},
}
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
s.NoError(err)
placeholderGroup := commonpb.PlaceholderGroup{}
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
_, err = runner.ProcessSearch(&placeholderGroup)
s.NoError(err)
}
}
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsert() {
ts := CreateOpenAIEmbeddingServer()
defer ts.Close()
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: openAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: dimParamKey, Value: "4"},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: ts.URL},
},
})
s.NoError(err)
data, err := testutil.CreateInsertData(s.schema, 100)
s.NoError(err)
{
input := []storage.FieldData{data.Data[101]}
_, err := runner.ProcessBulkInsert(input)
s.NoError(err)
}
// Multi-input
{
input := []storage.FieldData{data.Data[101], data.Data[101]}
_, err := runner.ProcessBulkInsert(input)
s.Error(err)
}
// Error input type
{
input := []storage.FieldData{data.Data[102]}
_, err := runner.ProcessBulkInsert(input)
s.Error(err)
}
}

View File

@ -0,0 +1,211 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"os"
"strings"
"sync"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type vertexAIJsonKey struct {
jsonKey []byte
once sync.Once
initErr error
}
var vtxKey vertexAIJsonKey
func getVertexAIJsonKey() ([]byte, error) {
vtxKey.once.Do(func() {
jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
jsonKey, err := os.ReadFile(jsonKeyPath)
if err != nil {
vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err)
return
}
vtxKey.jsonKey = jsonKey
})
return vtxKey.jsonKey, vtxKey.initErr
}
const (
vertexAIDocRetrival string = "DOC_RETRIEVAL"
vertexAICodeRetrival string = "CODE_RETRIEVAL"
vertexAISTS string = "STS"
)
func checkTask(modelName string, task string) error {
if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS {
return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS)
}
if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival {
return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival)
}
return nil
}
type VertexAIEmbeddingProvider struct {
fieldDim int64
client *vertexai.VertexAIEmbedding
modelName string
embedDimParam int64
task string
maxBatch int
timeoutSec int64
}
func createVertexAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) {
jsonKey, err := getVertexAIJsonKey()
if err != nil {
return nil, err
}
c := vertexai.NewVertexAIEmbedding(url, jsonKey, "https://www.googleapis.com/auth/cloud-platform", "")
return c, nil
}
func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertexAIEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil {
return nil, err
}
var location, projectID, task, modelName string
var dim int64
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case dimParamKey:
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
if err != nil {
return nil, err
}
case locationParamKey:
location = param.Value
case projectIDParamKey:
projectID = param.Value
case taskTypeParamKey:
task = param.Value
default:
}
}
if task == "" {
task = vertexAIDocRetrival
}
if err := checkTask(modelName, task); err != nil {
return nil, err
}
if location == "" {
location = "us-central1"
}
if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]",
modelName, textEmbedding005, textMultilingualEmbedding002)
}
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
var client *vertexai.VertexAIEmbedding
if c == nil {
client, err = createVertexAIEmbeddingClient(url)
if err != nil {
return nil, err
}
} else {
client = c
}
provider := VertexAIEmbeddingProvider{
fieldDim: fieldDim,
client: client,
modelName: modelName,
embedDimParam: dim,
task: task,
maxBatch: 128,
timeoutSec: 30,
}
return &provider, nil
}
func (provider *VertexAIEmbeddingProvider) MaxBatch() int {
return 5 * provider.maxBatch
}
func (provider *VertexAIEmbeddingProvider) FieldDim() int64 {
return provider.fieldDim
}
func (provider *VertexAIEmbeddingProvider) getTaskType(mode TextEmbeddingMode) string {
if mode == SearchMode {
switch provider.task {
case vertexAIDocRetrival:
return "RETRIEVAL_QUERY"
case vertexAICodeRetrival:
return "CODE_RETRIEVAL_QUERY"
case vertexAISTS:
return "SEMANTIC_SIMILARITY"
}
} else {
switch provider.task {
case vertexAIDocRetrival:
return "RETRIEVAL_DOCUMENT"
case vertexAICodeRetrival: // When inserting, the model does not distinguish between doc and code
return "RETRIEVAL_DOCUMENT"
case vertexAISTS:
return "SEMANTIC_SIMILARITY"
}
}
return ""
}
func (provider *VertexAIEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
numRows := len(texts)
taskType := provider.getTaskType(mode)
data := make([][]float32, 0, numRows)
for i := 0; i < numRows; i += provider.maxBatch {
end := i + provider.maxBatch
if end > numRows {
end = numRows
}
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, provider.timeoutSec)
if err != nil {
return nil, err
}
if end-i != len(resp.Predictions) {
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Predictions))
}
for _, item := range resp.Predictions {
if len(item.Embeddings.Values) != int(provider.fieldDim) {
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
provider.fieldDim, len(item.Embeddings.Values))
}
data = append(data, item.Embeddings.Values)
}
}
return data, nil
}

View File

@ -0,0 +1,276 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
)
func TestVertexAITextEmbeddingProvider(t *testing.T) {
suite.Run(t, new(VertexAITextEmbeddingProviderSuite))
}
type VertexAITextEmbeddingProviderSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *VertexAITextEmbeddingProviderSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
}
func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: textEmbedding005},
{Key: locationParamKey, Value: "mock_local"},
{Key: projectIDParamKey, Value: "mock_id"},
{Key: taskTypeParamKey, Value: vertexAICodeRetrival},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "4"},
},
}
mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token")
return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient)
}
func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() {
ts := CreateVertexAIEmbeddingServer()
defer ts.Close()
provder, err := createVertexAIProvider(ts.URL, s.schema.Fields[2])
s.NoError(err)
{
data := []string{"sentence"}
ret, err2 := provder.CallEmbedding(data, InsertMode)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(4, len(ret[0]))
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
}
{
data := []string{"sentence 1", "sentence 2", "sentence 3"}
ret, _ := provder.CallEmbedding(data, SearchMode)
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
}
}
func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res vertexai.EmbeddingResponse
res.Predictions = append(res.Predictions, vertexai.Prediction{
Embeddings: vertexai.Embeddings{
Statistics: vertexai.Statistics{
Truncated: false,
TokenCount: 10,
},
Values: []float32{1.0, 1.0, 1.0, 1.0},
},
})
res.Predictions = append(res.Predictions, vertexai.Prediction{
Embeddings: vertexai.Embeddings{
Statistics: vertexai.Statistics{
Truncated: false,
TokenCount: 10,
},
Values: []float32{1.0, 1.0},
},
})
res.Metadata = vertexai.Metadata{
BillableCharacterCount: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
provder, err := createVertexAIProvider(ts.URL, s.schema.Fields[2])
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res vertexai.EmbeddingResponse
res.Predictions = append(res.Predictions, vertexai.Prediction{
Embeddings: vertexai.Embeddings{
Statistics: vertexai.Statistics{
Truncated: false,
TokenCount: 10,
},
Values: []float32{1.0, 1.0, 1.0, 1.0},
},
})
res.Metadata = vertexai.Metadata{
BillableCharacterCount: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
provder, err := createVertexAIProvider(ts.URL, s.schema.Fields[2])
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence2"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
func (s *VertexAITextEmbeddingProviderSuite) TestCheckVertexAITask() {
err := checkTask(textMultilingualEmbedding002, "UnkownTask")
s.Error(err)
// textMultilingualEmbedding002 not support vertexAICodeRetrival task
err = checkTask(textMultilingualEmbedding002, vertexAICodeRetrival)
s.Error(err)
err = checkTask(textEmbedding005, vertexAICodeRetrival)
s.NoError(err)
err = checkTask(textMultilingualEmbedding002, vertexAISTS)
s.NoError(err)
}
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
defer os.Unsetenv(vertexServiceAccountJSONEnv)
_, err := getVertexAIJsonKey()
s.Error(err)
}
func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: textEmbedding005},
{Key: projectIDParamKey, Value: "mock_id"},
{Key: dimParamKey, Value: "4"},
},
}
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
{
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.NoError(err)
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY")
}
{
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival})
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.NoError(err)
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY")
}
{
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS}
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.NoError(err)
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
}
// invalid task
{
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: "UnkownTask"}
_, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.Error(err)
}
}
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
defer os.Unsetenv(vertexServiceAccountJSONEnv)
_, err := createVertexAIEmbeddingClient("https://mock_url.com")
s.Error(err)
}
func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: textEmbedding005},
{Key: projectIDParamKey, Value: "mock_id"},
{Key: dimParamKey, Value: "4"},
},
}
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.NoError(err)
s.True(provider.MaxBatch() > 0)
s.Equal(provider.FieldDim(), int64(4))
// check model name
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
_, err = NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
s.Error(err)
}

View File

@ -0,0 +1,152 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"fmt"
"os"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type VoyageAIEmbeddingProvider struct {
fieldDim int64
client *voyageai.VoyageAIEmbedding
modelName string
embedDimParam int64
maxBatch int
timeoutSec int64
}
func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) {
if apiKey == "" {
apiKey = os.Getenv(voyageAIAKEnvStr)
}
if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr)
}
if url == "" {
url = "https://api.voyageai.com/v1/embeddings"
}
c := voyageai.NewVoyageAIEmbeddingClient(apiKey, url)
return c, nil
}
func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*VoyageAIEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil {
return nil, err
}
var apiKey, url, modelName string
dim := int64(0)
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case dimParamKey:
// Only voyage-3-large and voyage-code-3 support dim param: 1024 (default), 256, 512, 2048
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
if err != nil {
return nil, err
}
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
default:
}
}
if modelName != voyage3Large && modelName != voyage3 && modelName != voyage3Lite && modelName != voyageCode3 && modelName != voyageFinance2 && modelName != voyageLaw2 && modelName != voyageCode2 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
modelName, voyage3Large, voyage3, voyage3Lite, voyageCode3, voyageFinance2, voyageLaw2, voyageCode2)
}
if dim != 0 {
if modelName != voyage3Large && modelName != voyageCode3 {
return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports dim parameter, only [%s, %s] support it.", modelName, voyage3, voyageCode3)
}
if dim != 1024 && dim != 256 && dim != 512 && dim != 2048 {
return nil, fmt.Errorf("VoyageAI text embedding model's dim only supports 2048, 1024 (default), 512, and 256.")
}
}
c, err := createVoyageAIEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
}
provider := VoyageAIEmbeddingProvider{
client: c,
fieldDim: fieldDim,
modelName: modelName,
embedDimParam: dim,
maxBatch: 128,
timeoutSec: 30,
}
return &provider, nil
}
func (provider *VoyageAIEmbeddingProvider) MaxBatch() int {
return 5 * provider.maxBatch
}
func (provider *VoyageAIEmbeddingProvider) FieldDim() int64 {
return provider.fieldDim
}
func (provider *VoyageAIEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
numRows := len(texts)
var textType string
if mode == InsertMode {
textType = "document"
} else {
textType = "query"
}
data := make([][]float32, 0, numRows)
for i := 0; i < numRows; i += provider.maxBatch {
end := i + provider.maxBatch
if end > numRows {
end = numRows
}
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, "float", provider.timeoutSec)
if err != nil {
return nil, err
}
if end-i != len(resp.Data) {
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data))
}
for _, item := range resp.Data {
if len(item.Embedding) != int(provider.fieldDim) {
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
provider.fieldDim, len(item.Embedding))
}
data = append(data, item.Embedding)
}
}
return data, nil
}

View File

@ -0,0 +1,221 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package function
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
)
func TestVoyageAITextEmbeddingProvider(t *testing.T) {
suite.Run(t, new(VoyageAITextEmbeddingProviderSuite))
}
type VoyageAITextEmbeddingProviderSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}
func (s *VoyageAITextEmbeddingProviderSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "1024"},
},
},
},
}
s.providers = []string{voyageAIProvider}
}
func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: voyage3Large},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
{Key: dimParamKey, Value: "1024"},
},
}
switch providerName {
case voyageAIProvider:
return NewVoyageAIEmbeddingProvider(schema, functionSchema)
default:
return nil, fmt.Errorf("Unknow provider")
}
}
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbedding() {
ts := CreateVoyageAIEmbeddingServer()
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
{
data := []string{"sentence"}
ret, err2 := provder.CallEmbedding(data, InsertMode)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(1024, len(ret[0]))
}
{
data := []string{"sentence 1", "sentence 2", "sentence 3"}
_, err := provder.CallEmbedding(data, SearchMode)
s.NoError(err)
}
}
}
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res voyageai.EmbeddingResponse
res.Data = append(res.Data, voyageai.EmbeddingData{
Object: "list",
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
Index: 0,
})
res.Data = append(res.Data, voyageai.EmbeddingData{
Object: "list",
Embedding: []float32{1.0, 1.0},
Index: 1,
})
res.Usage = voyageai.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
for _, providerName := range s.providers {
provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], providerName)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res voyageai.EmbeddingResponse
res.Data = append(res.Data, voyageai.EmbeddingData{
Object: "list",
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
Index: 0,
})
res.Usage = voyageai.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
defer ts.Close()
for _, provderName := range s.providers {
provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
// embedding dim not match
data := []string{"sentence", "sentence2"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}
func (s *VoyageAITextEmbeddingProviderSuite) TestCreateVoyageAIEmbeddingClient() {
_, err := createVoyageAIEmbeddingClient("", "")
s.Error(err)
os.Setenv(voyageAIAKEnvStr, "mockKey")
defer os.Unsetenv(voyageAIAKEnvStr)
_, err = createVoyageAIEmbeddingClient("", "")
s.NoError(err)
}
func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: voyage3Large},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: "mock"},
{Key: dimParamKey, Value: "1024"},
},
}
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.NoError(err)
s.Equal(provider.FieldDim(), int64(1024))
s.True(provider.MaxBatch() > 0)
// Invalid model
{
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}
// Invalid dim
{
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}
// Invalid dim type
{
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"}
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}
}

View File

@ -25,6 +25,21 @@ func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte {
return bytes
}
func Float32VectorsToPlaceholderGroup(embs [][]float32) *commonpb.PlaceholderGroup {
result := make([][]byte, 0, len(embs))
for _, floatVector := range embs {
result = append(result, floatVectorToByteVector(floatVector))
}
placeholderGroup := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: result,
}},
}
return placeholderGroup
}
func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
if err != nil {