mirror of https://github.com/milvus-io/milvus.git
feat: Add Text Embedding Function (#36366)
https://github.com/milvus-io/milvus/issues/35856 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>pull/39580/head
parent
47d280d974
commit
16cbdfb3b1
15
go.mod
15
go.mod
|
@ -59,6 +59,10 @@ require (
|
|||
require (
|
||||
cloud.google.com/go/storage v1.43.0
|
||||
github.com/antlr4-go/antlr/v4 v4.13.1
|
||||
github.com/aws/aws-sdk-go-v2 v1.32.6
|
||||
github.com/aws/aws-sdk-go-v2/config v1.28.6
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.47
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0
|
||||
github.com/bits-and-blooms/bitset v1.10.0
|
||||
github.com/bytedance/sonic v1.12.2
|
||||
github.com/cenkalti/backoff/v4 v4.2.1
|
||||
|
@ -101,6 +105,17 @@ require (
|
|||
github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7 // indirect
|
||||
github.com/apache/thrift v0.18.1 // indirect
|
||||
github.com/ardielle/ardielle-go v1.5.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect
|
||||
github.com/aws/smithy-go v1.22.1 // indirect
|
||||
github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.0 // indirect
|
||||
|
|
30
go.sum
30
go.sum
|
@ -120,6 +120,36 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5
|
|||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
|
||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
||||
github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
|
||||
github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 h1:mfV5tcLXeRLbiyI4EHoHWH1sIU7JvbfXVvymUCIgZEo=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0/go.mod h1:YSSgYnasDKm5OjU3bOPkaz+2PFO6WjEQGIA6KQNsR3Q=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8=
|
||||
github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro=
|
||||
github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
|
||||
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b h1:5JgaFtHFRnOPReItxvhMDXbvuBkjSWE+9glJyF466yw=
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/importutilv2"
|
||||
"github.com/milvus-io/milvus/internal/util/testutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
|
@ -435,6 +436,107 @@ func (s *SchedulerSuite) TestScheduler_ImportFile() {
|
|||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
|
||||
s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
|
||||
future := conc.Go(func() (struct{}, error) {
|
||||
return struct{}{}, nil
|
||||
})
|
||||
return future, nil
|
||||
})
|
||||
ts := function.CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: common.MaxLengthKey, Value: "128"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "4",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "int64",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{100},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{101},
|
||||
OutputFieldNames: []string{"vec"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
data, err := testutil.CreateInsertData(schema, s.numRows)
|
||||
s.NoError(err)
|
||||
s.reader = importutilv2.NewMockReader(s.T())
|
||||
s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) {
|
||||
var res *storage.InsertData
|
||||
once.Do(func() {
|
||||
res = data
|
||||
})
|
||||
if res != nil {
|
||||
return res, nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
})
|
||||
importReq := &datapb.ImportRequest{
|
||||
JobID: 10,
|
||||
TaskID: 11,
|
||||
CollectionID: 12,
|
||||
PartitionIDs: []int64{13},
|
||||
Vchannels: []string{"v0"},
|
||||
Schema: schema,
|
||||
Files: []*internalpb.ImportFile{
|
||||
{
|
||||
Paths: []string{"dummy.json"},
|
||||
},
|
||||
},
|
||||
Ts: 1000,
|
||||
IDRange: &datapb.IDRange{
|
||||
Begin: 0,
|
||||
End: int64(s.numRows),
|
||||
},
|
||||
RequestSegments: []*datapb.ImportRequestSegment{
|
||||
{
|
||||
SegmentID: 14,
|
||||
PartitionID: 13,
|
||||
Vchannel: "v0",
|
||||
},
|
||||
},
|
||||
}
|
||||
importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
|
||||
s.manager.Add(importTask)
|
||||
err = importTask.(*ImportTask).importFile(s.reader)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func TestScheduler(t *testing.T) {
|
||||
suite.Run(t, new(SchedulerSuite))
|
||||
}
|
||||
|
|
|
@ -199,12 +199,39 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData) error {
|
|||
}
|
||||
|
||||
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
|
||||
if err := RunBm25Function(task, data); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := RunDenseEmbedding(task, data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error {
|
||||
schema := task.GetSchema()
|
||||
if function.HasNonBM25Functions(schema.Functions, []int64{}) {
|
||||
exec, err := function.NewFunctionExecutor(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessBulkInsert(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
|
||||
fns := task.GetSchema().GetFunctions()
|
||||
for _, fn := range fns {
|
||||
runner, err := function.NewFunctionRunner(task.GetSchema(), fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if runner == nil {
|
||||
continue
|
||||
}
|
||||
inputDatas := make([]any, 0, len(fn.InputFieldIds))
|
||||
for _, inputFieldID := range fn.InputFieldIds {
|
||||
inputDatas = append(inputDatas, data.Data[inputFieldID].GetDataRows())
|
||||
|
|
|
@ -1099,8 +1099,8 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
|
|||
|
||||
if vectorField.GetIsFunctionOutput() {
|
||||
for _, function := range collSchema.Functions {
|
||||
if function.Type == schemapb.FunctionType_BM25 {
|
||||
// TODO: currently only BM25 function is supported, thus guarantees one input field to one output field
|
||||
if function.Type == schemapb.FunctionType_BM25 || function.Type == schemapb.FunctionType_TextEmbedding {
|
||||
// TODO: currently only BM25 & text embedding function is supported, thus guarantees one input field to one output field
|
||||
if function.OutputFieldNames[0] == vectorField.Name {
|
||||
dataType = schemapb.DataType_VarChar
|
||||
}
|
||||
|
|
|
@ -67,6 +67,9 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if functionRunner == nil {
|
||||
continue
|
||||
}
|
||||
node.functionRunners[tf.GetId()] = functionRunner
|
||||
}
|
||||
return node, nil
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
|
@ -141,6 +142,16 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
it.schema = schema.CollectionSchema
|
||||
|
||||
// Calculate embedding fields
|
||||
if function.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
|
||||
exec, err := function.NewFunctionExecutor(schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessInsert(it.insertMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
rowNums := uint32(it.insertMsg.NRows())
|
||||
// set insertTask.rowIDs
|
||||
var rowIDBegin UniqueID
|
||||
|
|
|
@ -10,7 +10,11 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/testutils"
|
||||
|
@ -308,3 +312,128 @@ func TestMaxInsertSize(t *testing.T) {
|
|||
assert.ErrorIs(t, err, merr.ErrParameterTooLarge)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertTask_Function(t *testing.T) {
|
||||
ts := function.CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
FieldName: "text",
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"sentence", "sentence"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
collectionName := "TestInsertTask_function"
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "TestInsertTask_function",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
|
||||
{
|
||||
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "200"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
IsFunctionOutput: true,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "test_function",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{101},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{102},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
|
||||
Status: merr.Status(nil),
|
||||
ID: 11198,
|
||||
Count: 10,
|
||||
}, nil)
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
|
||||
idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
assert.NoError(t, err)
|
||||
task := insertTask{
|
||||
ctx: context.Background(),
|
||||
insertMsg: &BaseInsertTask{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
CollectionName: collectionName,
|
||||
DbName: "hooooooo",
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
},
|
||||
Version: msgpb.InsertDataVersion_ColumnBased,
|
||||
FieldsData: data,
|
||||
NumRows: 2,
|
||||
},
|
||||
},
|
||||
schema: schema,
|
||||
idAllocator: idAllocator,
|
||||
}
|
||||
|
||||
info := newSchemaInfo(schema)
|
||||
cache := NewMockCache(t)
|
||||
cache.On("GetCollectionSchema",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(info, nil)
|
||||
|
||||
cache.On("GetPartitionInfo",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(&partitionInfo{
|
||||
name: "p1",
|
||||
partitionID: 10,
|
||||
createdTimestamp: 10001,
|
||||
createdUtcTimestamp: 10002,
|
||||
}, nil)
|
||||
cache.On("GetCollectionInfo",
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(&collectionInfo{schema: info}, nil)
|
||||
cache.On("GetDatabaseInfo",
|
||||
mock.Anything,
|
||||
mock.Anything,
|
||||
).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil)
|
||||
|
||||
globalMetaCache = cache
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
|
@ -362,6 +363,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
// fetch search_growing from search param
|
||||
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
||||
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
||||
queryFieldIds := []int64{}
|
||||
for index, subReq := range t.request.GetSubReqs() {
|
||||
plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), subReq.GetExprTemplateValues())
|
||||
if err != nil {
|
||||
|
@ -383,6 +385,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
}
|
||||
|
||||
internalSubReq.FieldId = queryInfo.GetQueryFieldId()
|
||||
queryFieldIds = append(queryFieldIds, internalSubReq.FieldId)
|
||||
// set PartitionIDs for sub search
|
||||
if t.partitionKeyMode {
|
||||
// isolation has tighter constraint, check first
|
||||
|
@ -421,6 +424,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
}
|
||||
|
||||
var err error
|
||||
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIds) {
|
||||
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
|
||||
t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
|
||||
|
||||
|
@ -428,7 +442,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||
if t.partitionKeyMode {
|
||||
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
|
||||
}
|
||||
var err error
|
||||
|
||||
t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
|
||||
|
@ -499,6 +513,16 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId
|
||||
t.SearchRequest.GroupSize = queryInfo.GroupSize
|
||||
|
||||
if function.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
|
||||
exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessSearch(t.SearchRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.Debug("proxy init search request",
|
||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -39,6 +40,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/proto/internalpb"
|
||||
|
@ -499,6 +501,251 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSearchTask_WithFunctions(t *testing.T) {
|
||||
ts := function.CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
collectionName := "TestSearchTask_function"
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "TestSearchTask_function",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
|
||||
{
|
||||
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "200"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "func1",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{101},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{102},
|
||||
OutputFieldNames: []string{"vector1"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "func2",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{101},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{103},
|
||||
OutputFieldNames: []string{"vector2"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
var (
|
||||
rc = NewRootCoordMock()
|
||||
qc = mocks.NewMockQueryCoordClient(t)
|
||||
ctx = context.TODO()
|
||||
)
|
||||
|
||||
defer rc.Close()
|
||||
require.NoError(t, err)
|
||||
mgr := newShardClientMgr()
|
||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
|
||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||
require.NoError(t, err)
|
||||
|
||||
getSearchTask := func(t *testing.T, collName string, data []string) *searchTask {
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_VarChar,
|
||||
Values: lo.Map(data, func(str string, _ int) []byte { return []byte(str) }),
|
||||
}
|
||||
holder := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
|
||||
}
|
||||
holderByte, _ := proto.Marshal(holder)
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collectionName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
Nq: int64(len(data)),
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: "vector1"},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
PlaceholderGroup: holderByte,
|
||||
},
|
||||
qc: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
collectionID := UniqueID(1000)
|
||||
cache := NewMockCache(t)
|
||||
info := newSchemaInfo(schema)
|
||||
cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe()
|
||||
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe()
|
||||
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
|
||||
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
|
||||
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
|
||||
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
|
||||
globalMetaCache = cache
|
||||
|
||||
{
|
||||
task := getSearchTask(t, collectionName, []string{"sentence"})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb)
|
||||
assert.Equal(t, len(pb.Placeholders), 1)
|
||||
assert.Equal(t, len(pb.Placeholders[0].Values), 1)
|
||||
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
|
||||
}
|
||||
|
||||
{
|
||||
task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb)
|
||||
assert.Equal(t, len(pb.Placeholders), 1)
|
||||
assert.Equal(t, len(pb.Placeholders[0].Values), 2)
|
||||
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
|
||||
}
|
||||
|
||||
// process failed
|
||||
{
|
||||
task := getSearchTask(t, collectionName, []string{"sentence"})
|
||||
task.request.Nq = 10000
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
getHybridSearchTask := func(t *testing.T, collName string, data [][]string) *searchTask {
|
||||
subReqs := []*milvuspb.SubSearchRequest{}
|
||||
for _, item := range data {
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_VarChar,
|
||||
Values: lo.Map(item, func(str string, _ int) []byte { return []byte(str) }),
|
||||
}
|
||||
holder := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
|
||||
}
|
||||
holderByte, _ := proto.Marshal(holder)
|
||||
subReq := &milvuspb.SubSearchRequest{
|
||||
PlaceholderGroup: holderByte,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: AnnsFieldKey, Value: "vector1"},
|
||||
{Key: TopKKey, Value: "10"},
|
||||
},
|
||||
Nq: int64(len(item)),
|
||||
}
|
||||
subReqs = append(subReqs, subReq)
|
||||
}
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
collectionName: collectionName,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
Timestamp: uint64(time.Now().UnixNano()),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
SubReqs: subReqs,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{Key: LimitKey, Value: "10"},
|
||||
},
|
||||
},
|
||||
qc: qc,
|
||||
tr: timerecord.NewTimeRecorder("test-search"),
|
||||
}
|
||||
require.NoError(t, task.OnEnqueue())
|
||||
return task
|
||||
}
|
||||
|
||||
{
|
||||
task := getHybridSearchTask(t, collectionName, [][]string{
|
||||
{"sentence1"},
|
||||
{"sentence2"},
|
||||
})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(task.SearchRequest.SubReqs), 2)
|
||||
for _, sub := range task.SearchRequest.SubReqs {
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(sub.PlaceholderGroup, pb)
|
||||
assert.Equal(t, len(pb.Placeholders), 1)
|
||||
assert.Equal(t, len(pb.Placeholders[0].Values), 1)
|
||||
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
task := getHybridSearchTask(t, collectionName, [][]string{
|
||||
{"sentence1", "sentence1"},
|
||||
{"sentence2", "sentence2"},
|
||||
{"sentence3", "sentence3"},
|
||||
})
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(task.SearchRequest.SubReqs), 3)
|
||||
for _, sub := range task.SearchRequest.SubReqs {
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(sub.PlaceholderGroup, pb)
|
||||
assert.Equal(t, len(pb.Placeholders), 1)
|
||||
assert.Equal(t, len(pb.Placeholders[0].Values), 2)
|
||||
assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector)
|
||||
}
|
||||
}
|
||||
// process failed
|
||||
{
|
||||
task := getHybridSearchTask(t, collectionName, [][]string{
|
||||
{"sentence1", "sentence1"},
|
||||
{"sentence2", "sentence2"},
|
||||
{"sentence3", "sentence3"},
|
||||
})
|
||||
task.request.SubReqs[0].Nq = 10000
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func getQueryCoord() *mocks.MockQueryCoord {
|
||||
qc := &mocks.MockQueryCoord{}
|
||||
qc.EXPECT().Start().Return(nil)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
@ -1022,6 +1023,47 @@ func TestCreateCollectionTask(t *testing.T) {
|
|||
err = task2.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("collection with embedding function ", func(t *testing.T) {
|
||||
fmt.Println(schema)
|
||||
schema.Functions = []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{varCharField},
|
||||
OutputFieldNames: []string{floatVecField},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task2 := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
err = task2.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = task2.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasCollectionTask(t *testing.T) {
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
|
@ -152,6 +153,16 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Calculate embedding fields
|
||||
if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
|
||||
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
|
||||
// set upsertTask.insertRequest.rowIDs
|
||||
tr := timerecord.NewTimeRecorder("applyPK")
|
||||
|
|
|
@ -27,8 +27,13 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/testutils"
|
||||
)
|
||||
|
||||
|
@ -360,3 +365,133 @@ func TestUpsertTaskForReplicate(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpsertTask_Function(t *testing.T) {
|
||||
ts := function.CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
data := []*schemapb.FieldData{}
|
||||
f1 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldId: 100,
|
||||
FieldName: "id",
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{0, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f1)
|
||||
f2 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
FieldName: "text",
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"sentence", "sentence"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f2)
|
||||
collectionName := "TestUpsertTask_function"
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "TestUpsertTask_function",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
|
||||
{
|
||||
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "200"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
IsFunctionOutput: true,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "test_function",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{101},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{102},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: ts.URL},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := newSchemaInfo(schema)
|
||||
collectionID := UniqueID(0)
|
||||
cache := NewMockCache(t)
|
||||
globalMetaCache = cache
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
|
||||
Status: merr.Status(nil),
|
||||
ID: collectionID,
|
||||
Count: 10,
|
||||
}, nil)
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
|
||||
idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
assert.NoError(t, err)
|
||||
task := upsertTask{
|
||||
ctx: context.Background(),
|
||||
req: &milvuspb.UpsertRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
upsertMsg: &msgstream.UpsertMsg{
|
||||
InsertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
||||
),
|
||||
CollectionName: collectionName,
|
||||
DbName: "hooooooo",
|
||||
Version: msgpb.InsertDataVersion_ColumnBased,
|
||||
FieldsData: data,
|
||||
NumRows: 2,
|
||||
PartitionName: Params.CommonCfg.DefaultPartitionName.GetValue(),
|
||||
},
|
||||
},
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
schema: info,
|
||||
result: &milvuspb.MutationResult{},
|
||||
}
|
||||
err = task.insertPreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// process failed
|
||||
{
|
||||
oldRows := task.upsertMsg.InsertMsg.InsertRequest.NumRows
|
||||
task.upsertMsg.InsertMsg.InsertRequest.NumRows = 10000
|
||||
err = task.insertPreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
task.upsertMsg.InsertMsg.InsertRequest.NumRows = oldRows
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
@ -705,6 +706,10 @@ func validateFunction(coll *schemapb.CollectionSchema) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := function.ValidateFunctions(coll); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -718,6 +723,10 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem
|
|||
if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) {
|
||||
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
|
||||
}
|
||||
case schemapb.FunctionType_TextEmbedding:
|
||||
if len(fields) != 1 || fields[0].DataType != schemapb.DataType_FloatVector {
|
||||
return fmt.Errorf("TextEmbedding function output field must be a FloatVector field")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("check output field for unknown function type")
|
||||
}
|
||||
|
@ -744,7 +753,10 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
|
|||
if !h.EnableAnalyzer() {
|
||||
return fmt.Errorf("BM25 function input field must set enable_analyzer to true")
|
||||
}
|
||||
|
||||
case schemapb.FunctionType_TextEmbedding:
|
||||
if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar {
|
||||
return fmt.Errorf("TextEmbedding function input field must be a VARCHAR field")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("check input field with unknown function type")
|
||||
}
|
||||
|
@ -786,6 +798,10 @@ func checkFunctionBasicParams(function *schemapb.FunctionSchema) error {
|
|||
if len(function.GetParams()) != 0 {
|
||||
return fmt.Errorf("BM25 function accepts no params")
|
||||
}
|
||||
case schemapb.FunctionType_TextEmbedding:
|
||||
if len(function.GetParams()) == 0 {
|
||||
return fmt.Errorf("TextEmbedding function need provider and model_name params")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("check function params with unknown function type")
|
||||
}
|
||||
|
@ -942,7 +958,7 @@ func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb
|
|||
expectColumnNum := 0
|
||||
for _, field := range schema.GetFields() {
|
||||
fieldName2Schema[field.Name] = field
|
||||
if !field.GetIsFunctionOutput() {
|
||||
if !IsBM25FunctionOutputField(field, schema) {
|
||||
expectColumnNum++
|
||||
}
|
||||
}
|
||||
|
@ -1494,12 +1510,12 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst
|
|||
if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey {
|
||||
return merr.WrapErrParameterInvalidMsg("primary key can't be with default value")
|
||||
}
|
||||
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() {
|
||||
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema, schema) {
|
||||
// when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false
|
||||
autoGenFieldNum++
|
||||
}
|
||||
if _, ok := dataNameSet[fieldSchema.GetName()]; !ok {
|
||||
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() {
|
||||
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema, schema) {
|
||||
// autoGenField
|
||||
continue
|
||||
}
|
||||
|
@ -1523,7 +1539,6 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst
|
|||
zap.Int64("primaryKeyNum", int64(primaryKeyNum)))
|
||||
return merr.WrapErrParameterInvalidMsg("more than 1 primary keys not supported, got %d", primaryKeyNum)
|
||||
}
|
||||
|
||||
expectedNum := len(schema.Fields)
|
||||
actualNum := len(insertMsg.FieldsData) + autoGenFieldNum
|
||||
|
||||
|
@ -2207,3 +2222,21 @@ func GetReplicateID(ctx context.Context, database, collectionName string) (strin
|
|||
replicateID, _ := common.GetReplicateID(dbInfo.properties)
|
||||
return replicateID, nil
|
||||
}
|
||||
|
||||
func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb.CollectionSchema) bool {
|
||||
if !(field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, fSchema := range collSchema.Functions {
|
||||
if fSchema.Type == schemapb.FunctionType_BM25 {
|
||||
if len(fSchema.OutputFieldNames) != 0 && field.Name == fSchema.OutputFieldNames[0] {
|
||||
return true
|
||||
}
|
||||
if len(fSchema.OutputFieldIds) != 0 && field.FieldID == fSchema.OutputFieldIds[0] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -2848,6 +2848,79 @@ func TestValidateFunction(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestValidateModelFunction(t *testing.T) {
|
||||
t.Run("Valid model function schema", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
|
||||
{
|
||||
Name: "output_dense_field", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
{
|
||||
Name: "text_embedding_func",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_dense_field"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "openai"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: "mock_url"},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid function schema - Invalid function info ", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector},
|
||||
{Name: "output_dense_field", DataType: schemapb.DataType_FloatVector},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
{
|
||||
Name: "text_embedding_func",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_dense_field"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "provider", Value: "UnkownProvider"},
|
||||
{Key: "model_name", Value: "text-embedding-ada-002"},
|
||||
{Key: "api_key", Value: "mock"},
|
||||
{Key: "url", Value: "mock_url"},
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := validateFunction(schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateFunctionInputField(t *testing.T) {
|
||||
t.Run("Valid BM25 function input", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
|
@ -2920,6 +2993,28 @@ func TestValidateFunctionInputField(t *testing.T) {
|
|||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid TextEmbedding function input - multiple fields", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid TextEmbedding function input - wrong type", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
err := checkFunctionInputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateFunctionOutputField(t *testing.T) {
|
||||
|
@ -2977,6 +3072,28 @@ func TestValidateFunctionOutputField(t *testing.T) {
|
|||
err := checkFunctionOutputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid TextEmbedding function input - multiple fields", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{}
|
||||
err := checkFunctionOutputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid TextEmbedding function input - wrong type", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
}
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
err := checkFunctionOutputField(function, fields)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateFunctionBasicParams(t *testing.T) {
|
||||
|
@ -3078,6 +3195,36 @@ func TestValidateFunctionBasicParams(t *testing.T) {
|
|||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty text embedding params", func(t *testing.T) {
|
||||
function := &schemapb.FunctionSchema{
|
||||
Name: "textEmbeddingParam",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"input1"},
|
||||
OutputFieldNames: []string{"output1"},
|
||||
}
|
||||
err := checkFunctionBasicParams(function)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsBM25FunctionOutputField(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{Name: "input_field", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: "enable_analyzer", Value: "true"}}},
|
||||
{Name: "output_field", DataType: schemapb.DataType_SparseFloatVector, IsFunctionOutput: true},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "bm25_func",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"input_field"},
|
||||
OutputFieldNames: []string{"output_field"},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.False(t, IsBM25FunctionOutputField(schema.Fields[0], schema))
|
||||
assert.True(t, IsBM25FunctionOutputField(schema.Fields[1], schema))
|
||||
}
|
||||
|
||||
func TestComputeRecall(t *testing.T) {
|
||||
|
|
|
@ -70,6 +70,9 @@ func newEmbeddingNode(collectionID int64, channelName string, manager *DataManag
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if functionRunner == nil {
|
||||
continue
|
||||
}
|
||||
node.functionRunners = append(node.functionRunners, functionRunner)
|
||||
}
|
||||
return node, nil
|
||||
|
|
|
@ -382,7 +382,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap
|
|||
}
|
||||
|
||||
for _, field := range collSchema.Fields {
|
||||
if skipFunction && field.GetIsFunctionOutput() {
|
||||
if skipFunction && IsBM25FunctionOutputField(field, collSchema) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -527,7 +527,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
|
|||
}
|
||||
length := 0
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.GetIsFunctionOutput() {
|
||||
if IsBM25FunctionOutputField(field, collSchema) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -1405,3 +1405,22 @@ func (ni NullableInt) GetValue() int {
|
|||
func (ni NullableInt) IsNull() bool {
|
||||
return ni.Value == nil
|
||||
}
|
||||
|
||||
// TODO: unify the function implementation, storage/utils.go & proxy/util.go
|
||||
func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb.CollectionSchema) bool {
|
||||
if !(field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, fSchema := range collSchema.Functions {
|
||||
if fSchema.Type == schemapb.FunctionType_BM25 {
|
||||
if len(fSchema.OutputFieldNames) != 0 && field.Name == fSchema.OutputFieldNames[0] {
|
||||
return true
|
||||
}
|
||||
if len(fSchema.OutputFieldIds) != 0 && field.FieldID == fSchema.OutputFieldIds[0] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1949,3 +1949,53 @@ func TestJson(t *testing.T) {
|
|||
t.Log(string(ExtraBytes))
|
||||
t.Log(ExtraLength)
|
||||
}
|
||||
|
||||
func TestBM25Checker(t *testing.T) {
|
||||
f1 := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
}
|
||||
f2 := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_BM25,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"sparse"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{103},
|
||||
}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
IsFunctionOutput: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 103, Name: "sparse",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
IsFunctionOutput: true,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{f1, f2},
|
||||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
isBm25 := IsBM25FunctionOutputField(field, schema)
|
||||
if field.FieldID == 103 {
|
||||
assert.True(t, isBm25)
|
||||
} else {
|
||||
assert.False(t, isBm25)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/ali"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type AliEmbeddingProvider struct {
|
||||
fieldDim int64
|
||||
|
||||
client *ali.AliDashScopeEmbedding
|
||||
modelName string
|
||||
embedDimParam int64
|
||||
outputType string
|
||||
|
||||
maxBatch int
|
||||
timeoutSec int64
|
||||
}
|
||||
|
||||
func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv(dashscopeAKEnvStr)
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
|
||||
}
|
||||
|
||||
if url == "" {
|
||||
url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
}
|
||||
c := ali.NewAliDashScopeEmbeddingClient(apiKey, url)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) {
|
||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName string
|
||||
var dim int64
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case dimParamKey:
|
||||
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
|
||||
modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3)
|
||||
}
|
||||
|
||||
c, err := createAliEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxBatch := 25
|
||||
if modelName == TextEmbeddingV3 {
|
||||
maxBatch = 6
|
||||
}
|
||||
|
||||
provider := AliEmbeddingProvider{
|
||||
client: c,
|
||||
fieldDim: fieldDim,
|
||||
modelName: modelName,
|
||||
embedDimParam: dim,
|
||||
// TextEmbedding only supports dense embedding
|
||||
outputType: "dense",
|
||||
maxBatch: maxBatch,
|
||||
timeoutSec: 30,
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
func (provider *AliEmbeddingProvider) MaxBatch() int {
|
||||
return 5 * provider.maxBatch
|
||||
}
|
||||
|
||||
func (provider *AliEmbeddingProvider) FieldDim() int64 {
|
||||
return provider.fieldDim
|
||||
}
|
||||
|
||||
func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
|
||||
numRows := len(texts)
|
||||
var textType string
|
||||
if mode == SearchMode {
|
||||
textType = "query"
|
||||
} else {
|
||||
textType = "document"
|
||||
}
|
||||
data := make([][]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i += provider.maxBatch {
|
||||
end := i + provider.maxBatch
|
||||
if end > numRows {
|
||||
end = numRows
|
||||
}
|
||||
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if end-i != len(resp.Output.Embeddings) {
|
||||
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Output.Embeddings))
|
||||
}
|
||||
for _, item := range resp.Output.Embeddings {
|
||||
if len(item.Embedding) != int(provider.fieldDim) {
|
||||
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
|
||||
provider.fieldDim, len(item.Embedding))
|
||||
}
|
||||
data = append(data, item.Embedding)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/ali"
|
||||
)
|
||||
|
||||
func TestAliTextEmbeddingProvider(t *testing.T) {
|
||||
suite.Run(t, new(AliTextEmbeddingProviderSuite))
|
||||
}
|
||||
|
||||
type AliTextEmbeddingProviderSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.providers = []string{aliDashScopeProvider}
|
||||
}
|
||||
|
||||
func createAliProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: TextEmbeddingV3},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
switch providerName {
|
||||
case aliDashScopeProvider:
|
||||
return NewAliDashScopeEmbeddingProvider(schema, functionSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknow provider")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) TestEmbedding() {
|
||||
ts := CreateAliEmbeddingServer()
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
s.NoError(err)
|
||||
{
|
||||
data := []string{"sentence"}
|
||||
ret, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(4, len(ret[0]))
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
|
||||
}
|
||||
{
|
||||
data := []string{"sentence 1", "sentence 2", "sentence 3"}
|
||||
ret, _ := provder.CallEmbedding(data, SearchMode)
|
||||
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res ali.EmbeddingResponse
|
||||
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
|
||||
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
TextIndex: 0,
|
||||
})
|
||||
|
||||
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
|
||||
Embedding: []float32{1.0, 1.0},
|
||||
TextIndex: 1,
|
||||
})
|
||||
res.Usage = ali.Usage{
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
for _, providerName := range s.providers {
|
||||
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], providerName)
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res ali.EmbeddingResponse
|
||||
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
|
||||
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
TextIndex: 0,
|
||||
})
|
||||
res.Usage = ali.Usage{
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence2"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) TestCreateAliEmbeddingClient() {
|
||||
_, err := createAliEmbeddingClient("", "")
|
||||
s.Error(err)
|
||||
|
||||
os.Setenv(dashscopeAKEnvStr, "mock_key")
|
||||
defer os.Unsetenv(dashscopeAKEnvStr)
|
||||
_, err = createAliEmbeddingClient("", "")
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *AliTextEmbeddingProviderSuite) TestNewAliDashScopeEmbeddingProvider() {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: "UnkownModels"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
_, err := NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
|
||||
// invalid dim
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: TextEmbeddingV3}
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||
_, err = NewAliDashScopeEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type BedrockClient interface {
|
||||
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
|
||||
}
|
||||
|
||||
type BedrockEmbeddingProvider struct {
|
||||
fieldDim int64
|
||||
|
||||
client BedrockClient
|
||||
modelName string
|
||||
embedDimParam int64
|
||||
normalize bool
|
||||
|
||||
maxBatch int
|
||||
timeoutSec int
|
||||
}
|
||||
|
||||
func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) {
|
||||
if awsAccessKeyId == "" {
|
||||
awsAccessKeyId = os.Getenv(bedrockAccessKeyId)
|
||||
}
|
||||
if awsAccessKeyId == "" {
|
||||
return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId)
|
||||
}
|
||||
|
||||
if awsSecretAccessKey == "" {
|
||||
awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr)
|
||||
}
|
||||
if awsSecretAccessKey == "" {
|
||||
return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr)
|
||||
}
|
||||
if region == "" {
|
||||
return nil, fmt.Errorf("Missing AWS Service region. Please pass `region` param.")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||||
awsAccessKeyId, awsSecretAccessKey, "")),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bedrockruntime.NewFromConfig(cfg), nil
|
||||
}
|
||||
|
||||
func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) {
|
||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var awsAccessKeyId, awsSecretAccessKey, region, modelName string
|
||||
var dim int64
|
||||
normalize := true
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case dimParamKey:
|
||||
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case awsAKIdParamKey:
|
||||
awsAccessKeyId = param.Value
|
||||
case awsSAKParamKey:
|
||||
awsSecretAccessKey = param.Value
|
||||
case regionParamKey:
|
||||
region = param.Value
|
||||
case normalizeParamKey:
|
||||
switch strings.ToLower(param.Value) {
|
||||
case "false":
|
||||
normalize = false
|
||||
case "true":
|
||||
normalize = true
|
||||
default:
|
||||
return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != BedRockTitanTextEmbeddingsV2 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s]",
|
||||
modelName, BedRockTitanTextEmbeddingsV2)
|
||||
}
|
||||
var client BedrockClient
|
||||
if c == nil {
|
||||
client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
client = c
|
||||
}
|
||||
|
||||
return &BedrockEmbeddingProvider{
|
||||
client: client,
|
||||
fieldDim: fieldDim,
|
||||
modelName: modelName,
|
||||
embedDimParam: dim,
|
||||
normalize: normalize,
|
||||
maxBatch: 1,
|
||||
timeoutSec: 30,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (provider *BedrockEmbeddingProvider) MaxBatch() int {
|
||||
// The bedrock model does not support batches, we support a small batch on the milvus side.
|
||||
return 12 * provider.maxBatch
|
||||
}
|
||||
|
||||
func (provider *BedrockEmbeddingProvider) FieldDim() int64 {
|
||||
return provider.fieldDim
|
||||
}
|
||||
|
||||
func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, _ TextEmbeddingMode) ([][]float32, error) {
|
||||
numRows := len(texts)
|
||||
data := make([][]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i += 1 {
|
||||
payload := BedRockRequest{
|
||||
InputText: texts[i],
|
||||
Normalize: provider.normalize,
|
||||
}
|
||||
if provider.embedDimParam != 0 {
|
||||
payload.Dimensions = provider.embedDimParam
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output, err := provider.client.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{
|
||||
Body: payloadBytes,
|
||||
ModelId: aws.String(provider.modelName),
|
||||
ContentType: aws.String("application/json"),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp BedRockResponse
|
||||
err = json.Unmarshal(output.Body, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(resp.Embedding) != int(provider.fieldDim) {
|
||||
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
|
||||
provider.fieldDim, len(resp.Embedding))
|
||||
}
|
||||
data = append(data, resp.Embedding)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type BedRockRequest struct {
|
||||
InputText string `json:"inputText"`
|
||||
Dimensions int64 `json:"dimensions,omitempty"`
|
||||
Normalize bool `json:"normalize,omitempty"`
|
||||
}
|
||||
|
||||
type BedRockResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
InputTextTokenCount int `json:"inputTextTokenCount"`
|
||||
}
|
|
@ -0,0 +1,177 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func TestBedrockTextEmbeddingProvider(t *testing.T) {
|
||||
suite.Run(t, new(BedrockTextEmbeddingProviderSuite))
|
||||
}
|
||||
|
||||
type BedrockTextEmbeddingProviderSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *BedrockTextEmbeddingProviderSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.providers = []string{bedrockProvider}
|
||||
}
|
||||
|
||||
func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, dim int) (textEmbeddingProvider, error) {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
switch providerName {
|
||||
case bedrockProvider:
|
||||
return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim})
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknow provider")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() {
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 4)
|
||||
s.NoError(err)
|
||||
{
|
||||
data := []string{"sentence"}
|
||||
ret, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(4, len(ret[0]))
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
|
||||
}
|
||||
{
|
||||
data := []string{"sentence 1", "sentence 2", "sentence 3"}
|
||||
ret, _ := provder.CallEmbedding(data, SearchMode)
|
||||
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 2)
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BedrockTextEmbeddingProviderSuite) TestCreateBedrockClient() {
|
||||
_, err := createBedRockEmbeddingClient("", "", "")
|
||||
s.Error(err)
|
||||
|
||||
_, err = createBedRockEmbeddingClient("mock_id", "", "")
|
||||
s.Error(err)
|
||||
|
||||
_, err = createBedRockEmbeddingClient("", "mock_key", "")
|
||||
s.Error(err)
|
||||
|
||||
_, err = createBedRockEmbeddingClient("mock_id", "mock_key", "")
|
||||
s.Error(err)
|
||||
|
||||
_, err = createBedRockEmbeddingClient("mock_id", "mock_key", "mock_region")
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *BedrockTextEmbeddingProviderSuite) TestNewBedrockEmbeddingProvider() {
|
||||
fieldSchema := &schemapb.FieldSchema{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
}
|
||||
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
|
||||
{Key: awsAKIdParamKey, Value: "mock"},
|
||||
{Key: awsSAKParamKey, Value: "mock"},
|
||||
{Key: regionParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: normalizeParamKey, Value: "false"},
|
||||
},
|
||||
}
|
||||
provider, err := NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.NoError(err)
|
||||
s.True(provider.MaxBatch() > 0)
|
||||
s.Equal(provider.FieldDim(), int64(4))
|
||||
|
||||
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.NoError(err)
|
||||
|
||||
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "invalid"}
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.Error(err)
|
||||
|
||||
// invalid model name
|
||||
functionSchema.Params[5] = &commonpb.KeyValuePair{Key: normalizeParamKey, Value: "true"}
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.Error(err)
|
||||
|
||||
// invalid dim
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalid"}
|
||||
_, err = NewBedrockEmbeddingProvider(fieldSchema, functionSchema, nil)
|
||||
s.Error(err)
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type TextEmbeddingMode int
|
||||
|
||||
const (
|
||||
InsertMode TextEmbeddingMode = iota
|
||||
SearchMode
|
||||
)
|
||||
|
||||
// common params
|
||||
const (
|
||||
modelNameParamKey string = "model_name"
|
||||
dimParamKey string = "dim"
|
||||
embeddingURLParamKey string = "url"
|
||||
apiKeyParamKey string = "api_key"
|
||||
)
|
||||
|
||||
// ali text embedding
|
||||
const (
|
||||
TextEmbeddingV1 string = "text-embedding-v1"
|
||||
TextEmbeddingV2 string = "text-embedding-v2"
|
||||
TextEmbeddingV3 string = "text-embedding-v3"
|
||||
|
||||
dashscopeAKEnvStr string = "MILVUSAI_DASHSCOPE_API_KEY"
|
||||
)
|
||||
|
||||
// openai/azure text embedding
|
||||
|
||||
const (
|
||||
TextEmbeddingAda002 string = "text-embedding-ada-002"
|
||||
TextEmbedding3Small string = "text-embedding-3-small"
|
||||
TextEmbedding3Large string = "text-embedding-3-large"
|
||||
|
||||
openaiAKEnvStr string = "MILVUSAI_OPENAI_API_KEY"
|
||||
|
||||
azureOpenaiAKEnvStr string = "MILVUSAI_AZURE_OPENAI_API_KEY"
|
||||
azureOpenaiResourceName string = "MILVUSAI_AZURE_OPENAI_RESOURCE_NAME"
|
||||
|
||||
userParamKey string = "user"
|
||||
)
|
||||
|
||||
// bedrock emebdding
|
||||
|
||||
const (
|
||||
BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0"
|
||||
awsAKIdParamKey string = "aws_access_key_id"
|
||||
awsSAKParamKey string = "aws_secret_access_key"
|
||||
regionParamKey string = "regin"
|
||||
normalizeParamKey string = "normalize"
|
||||
|
||||
bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID"
|
||||
bedrockSAKEnvStr string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY"
|
||||
)
|
||||
|
||||
// vertexAI
|
||||
|
||||
const (
|
||||
locationParamKey string = "location"
|
||||
projectIDParamKey string = "projectid"
|
||||
taskTypeParamKey string = "task"
|
||||
|
||||
textEmbedding005 string = "text-embedding-005"
|
||||
textMultilingualEmbedding002 string = "text-multilingual-embedding-002"
|
||||
|
||||
vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS"
|
||||
)
|
||||
|
||||
// voyageAI
|
||||
const (
|
||||
voyage3Large string = "voyage-3-large"
|
||||
voyage3 string = "voyage-3"
|
||||
voyage3Lite string = "voyage-3-lite"
|
||||
voyageCode3 string = "voyage-code-3"
|
||||
voyageFinance2 string = "voyage-finance-2"
|
||||
voyageLaw2 string = "voyage-law-2"
|
||||
voyageCode2 string = "voyage-code-2"
|
||||
|
||||
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
|
||||
)
|
||||
|
||||
func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
|
||||
dim, err := strconv.ParseInt(dimStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("dimension [%s] provided in Function params is not a valid int", dimStr)
|
||||
}
|
||||
|
||||
if dim != 0 && dim != fieldDim {
|
||||
return 0, fmt.Errorf("Function output field:[%s]'s dimension [%d] does not match the dimension [%d] provided in Function params.", fieldName, fieldDim, dim)
|
||||
}
|
||||
return dim, nil
|
||||
}
|
|
@ -35,6 +35,8 @@ func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Functio
|
|||
switch schema.GetType() {
|
||||
case schemapb.FunctionType_BM25:
|
||||
return NewBM25FunctionRunner(coll, schema)
|
||||
case schemapb.FunctionType_TextEmbedding:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
type FunctionBase struct {
|
||||
schema *schemapb.FunctionSchema
|
||||
outputFields []*schemapb.FieldSchema
|
||||
}
|
||||
|
||||
func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) {
|
||||
var base FunctionBase
|
||||
base.schema = fSchema
|
||||
for _, fieldName := range fSchema.GetOutputFieldNames() {
|
||||
for _, field := range coll.GetFields() {
|
||||
if field.GetName() == fieldName {
|
||||
base.outputFields = append(base.outputFields, field)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(base.outputFields) != len(fSchema.GetOutputFieldNames()) {
|
||||
return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema",
|
||||
coll.Name, fSchema.Name)
|
||||
}
|
||||
return &base, nil
|
||||
}
|
||||
|
||||
func (base *FunctionBase) GetSchema() *schemapb.FunctionSchema {
|
||||
return base.schema
|
||||
}
|
||||
|
||||
func (base *FunctionBase) GetOutputFields() []*schemapb.FieldSchema {
|
||||
return base.outputFields
|
||||
}
|
|
@ -0,0 +1,255 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
GetSchema() *schemapb.FunctionSchema
|
||||
GetOutputFields() []*schemapb.FieldSchema
|
||||
|
||||
MaxBatch() int
|
||||
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
|
||||
ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error)
|
||||
ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error)
|
||||
}
|
||||
|
||||
type FunctionExecutor struct {
|
||||
runners map[int64]Runner
|
||||
}
|
||||
|
||||
func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (Runner, error) {
|
||||
switch schema.GetType() {
|
||||
case schemapb.FunctionType_BM25: // ignore bm25 function
|
||||
return nil, nil
|
||||
case schemapb.FunctionType_TextEmbedding:
|
||||
f, err := NewTextEmbeddingFunction(coll, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String())
|
||||
}
|
||||
}
|
||||
|
||||
// Since bm25 and embedding are implemented in different ways, the bm25 function is not verified here.
|
||||
func ValidateFunctions(schema *schemapb.CollectionSchema) error {
|
||||
for _, fSchema := range schema.Functions {
|
||||
if _, err := createFunction(schema, fSchema); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
|
||||
executor := &FunctionExecutor{
|
||||
runners: make(map[int64]Runner),
|
||||
}
|
||||
for _, fSchema := range schema.Functions {
|
||||
runner, err := createFunction(schema, fSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if runner != nil {
|
||||
executor.runners[fSchema.GetOutputFieldIds()[0]] = runner
|
||||
}
|
||||
}
|
||||
return executor, nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
|
||||
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames()))
|
||||
for _, name := range runner.GetSchema().GetInputFieldNames() {
|
||||
for _, field := range msg.FieldsData {
|
||||
if field.GetFieldName() == name {
|
||||
inputs = append(inputs, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(inputs) != len(runner.GetSchema().InputFieldIds) {
|
||||
return nil, fmt.Errorf("Input field not found")
|
||||
}
|
||||
|
||||
outputs, err := runner.ProcessInsert(inputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return outputs, nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error {
|
||||
numRows := msg.NumRows
|
||||
for _, runner := range executor.runners {
|
||||
if numRows > uint64(runner.MaxBatch()) {
|
||||
return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch())
|
||||
}
|
||||
}
|
||||
|
||||
outputs := make(chan []*schemapb.FieldData, len(executor.runners))
|
||||
errChan := make(chan error, len(executor.runners))
|
||||
var wg sync.WaitGroup
|
||||
for _, runner := range executor.runners {
|
||||
wg.Add(1)
|
||||
go func(runner Runner) {
|
||||
defer wg.Done()
|
||||
data, err := executor.processSingleFunction(runner, msg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
outputs <- data
|
||||
}(runner)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
close(outputs)
|
||||
|
||||
// Collect all errors
|
||||
var errs []error
|
||||
for err := range errChan {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("multiple errors occurred: %v", errs)
|
||||
}
|
||||
|
||||
for output := range outputs {
|
||||
msg.FieldsData = append(msg.FieldsData, output...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) {
|
||||
pb := &commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroup, pb)
|
||||
if len(pb.Placeholders) != 1 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("No placeholders founded")
|
||||
}
|
||||
if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar {
|
||||
return placeholderGroup, nil
|
||||
}
|
||||
res, err := runner.ProcessSearch(pb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proto.Marshal(res)
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error {
|
||||
runner, exist := executor.runners[req.FieldId]
|
||||
if !exist {
|
||||
return fmt.Errorf("Can not found function in field %d", req.FieldId)
|
||||
}
|
||||
if req.Nq > int64(runner.MaxBatch()) {
|
||||
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch())
|
||||
}
|
||||
if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil {
|
||||
return err
|
||||
} else {
|
||||
req.PlaceholderGroup = newHolder
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error {
|
||||
outputs := make(chan map[int64][]byte, len(req.GetSubReqs()))
|
||||
errChan := make(chan error, len(req.GetSubReqs()))
|
||||
var wg sync.WaitGroup
|
||||
for idx, sub := range req.GetSubReqs() {
|
||||
if runner, exist := executor.runners[sub.FieldId]; exist {
|
||||
if sub.Nq > int64(runner.MaxBatch()) {
|
||||
return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch())
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(runner Runner, idx int64, placeholderGroup []byte) {
|
||||
defer wg.Done()
|
||||
if newHolder, err := executor.processSingleSearch(runner, placeholderGroup); err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
outputs <- map[int64][]byte{idx: newHolder}
|
||||
}
|
||||
}(runner, int64(idx), sub.GetPlaceholderGroup())
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
close(outputs)
|
||||
for err := range errChan {
|
||||
return err
|
||||
}
|
||||
|
||||
for output := range outputs {
|
||||
for idx, holder := range output {
|
||||
req.SubReqs[idx].PlaceholderGroup = holder
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error {
|
||||
if !req.IsAdvanced {
|
||||
return executor.prcessSearch(req)
|
||||
}
|
||||
return executor.prcessAdvanceSearch(req)
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) {
|
||||
inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds))
|
||||
for idx, id := range runner.GetSchema().InputFieldIds {
|
||||
field, exist := data.Data[id]
|
||||
if !exist {
|
||||
return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx])
|
||||
}
|
||||
inputs = append(inputs, field)
|
||||
}
|
||||
|
||||
outputs, err := runner.ProcessBulkInsert(inputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return outputs, nil
|
||||
}
|
||||
|
||||
func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error {
|
||||
// Since concurrency has already been used in the outer layer, only a serial logic access model is used here.
|
||||
for _, runner := range executor.runners {
|
||||
output, err := executor.processSingleBulkInsert(runner, data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for k, v := range output {
|
||||
data.Data[k] = v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,340 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/openai"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
)
|
||||
|
||||
func TestFunctionExecutor(t *testing.T) {
|
||||
suite.Run(t, new(FunctionExecutorSuite))
|
||||
}
|
||||
|
||||
type FunctionExecutorSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema {
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
IsFunctionOutput: true,
|
||||
},
|
||||
{
|
||||
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "8"},
|
||||
},
|
||||
IsFunctionOutput: true,
|
||||
},
|
||||
},
|
||||
Functions: []*schemapb.FunctionSchema{
|
||||
{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{101},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{102},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldIds: []int64{101},
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldIds: []int64{103},
|
||||
OutputFieldNames: []string{"vector2"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: dimParamKey, Value: "8"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg {
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
FieldName: "text",
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: texts,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
|
||||
msg := msgstream.InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
FieldsData: data,
|
||||
},
|
||||
}
|
||||
return &msg
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]float32 {
|
||||
embeddings := make([][]float32, 0)
|
||||
for i := 0; i < len(texts); i++ {
|
||||
f := float32(i)
|
||||
emb := make([]float32, 0)
|
||||
for j := 0; j < dim; j++ {
|
||||
emb = append(emb, f+float32(j)*0.1)
|
||||
}
|
||||
embeddings = append(embeddings, emb)
|
||||
}
|
||||
return embeddings
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) TestExecutor() {
|
||||
ts := CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
schema := s.creataSchema(ts.URL)
|
||||
exec, err := NewFunctionExecutor(schema)
|
||||
s.NoError(err)
|
||||
msg := s.createMsg([]string{"sentence", "sentence"})
|
||||
exec.ProcessInsert(msg)
|
||||
s.Equal(len(msg.FieldsData), 3)
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) TestErrorEmbedding() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req openai.EmbeddingRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
var res openai.EmbeddingResponse
|
||||
res.Object = "list"
|
||||
res.Model = "text-embedding-3-small"
|
||||
for i := 0; i < len(req.Input); i++ {
|
||||
res.Data = append(res.Data, openai.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{},
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: 1,
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer ts.Close()
|
||||
schema := s.creataSchema(ts.URL)
|
||||
exec, err := NewFunctionExecutor(schema)
|
||||
fmt.Println(err)
|
||||
s.NoError(err)
|
||||
msg := s.createMsg([]string{"sentence", "sentence"})
|
||||
err = exec.ProcessInsert(msg)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) TestErrorSchema() {
|
||||
schema := s.creataSchema("http://localhost")
|
||||
schema.Functions[0].Type = schemapb.FunctionType_Unknown
|
||||
_, err := NewFunctionExecutor(schema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) TestInternalPrcessSearch() {
|
||||
ts := CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
schema := s.creataSchema(ts.URL)
|
||||
exec, err := NewFunctionExecutor(schema)
|
||||
s.NoError(err)
|
||||
|
||||
{
|
||||
f := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strings.Split("helle,world", ","),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
|
||||
s.NoError(err)
|
||||
|
||||
req := &internalpb.SearchRequest{
|
||||
Nq: 2,
|
||||
PlaceholderGroup: placeholderGroupBytes,
|
||||
IsAdvanced: false,
|
||||
FieldId: 102,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
s.NoError(err)
|
||||
|
||||
// No function found
|
||||
req = &internalpb.SearchRequest{
|
||||
Nq: 2,
|
||||
PlaceholderGroup: placeholderGroupBytes,
|
||||
IsAdvanced: false,
|
||||
FieldId: 111,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
s.Error(err)
|
||||
|
||||
// Large search nq
|
||||
req = &internalpb.SearchRequest{
|
||||
Nq: 1000,
|
||||
PlaceholderGroup: placeholderGroupBytes,
|
||||
IsAdvanced: false,
|
||||
FieldId: 102,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// AdvanceSearch
|
||||
{
|
||||
f := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strings.Split("helle,world", ","),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
|
||||
s.NoError(err)
|
||||
|
||||
subReq := &internalpb.SubSearchRequest{
|
||||
PlaceholderGroup: placeholderGroupBytes,
|
||||
Nq: 2,
|
||||
FieldId: 102,
|
||||
}
|
||||
req := &internalpb.SearchRequest{
|
||||
IsAdvanced: true,
|
||||
SubReqs: []*internalpb.SubSearchRequest{subReq},
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
s.NoError(err)
|
||||
|
||||
// Large nq
|
||||
subReq.Nq = 1000
|
||||
err = exec.ProcessSearch(req)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FunctionExecutorSuite) TestInternalPrcessSearchFailed() {
|
||||
ts := CreateErrorEmbeddingServer()
|
||||
defer ts.Close()
|
||||
|
||||
schema := s.creataSchema(ts.URL)
|
||||
exec, err := NewFunctionExecutor(schema)
|
||||
s.NoError(err)
|
||||
f := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strings.Split("helle,world", ","),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
|
||||
s.NoError(err)
|
||||
|
||||
{
|
||||
req := &internalpb.SearchRequest{
|
||||
Nq: 2,
|
||||
PlaceholderGroup: placeholderGroupBytes,
|
||||
IsAdvanced: false,
|
||||
FieldId: 102,
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
s.Error(err)
|
||||
}
|
||||
// AdvanceSearch
|
||||
{
|
||||
subReq := &internalpb.SubSearchRequest{
|
||||
PlaceholderGroup: placeholderGroupBytes,
|
||||
Nq: 2,
|
||||
FieldId: 102,
|
||||
}
|
||||
req := &internalpb.SearchRequest{
|
||||
IsAdvanced: true,
|
||||
SubReqs: []*internalpb.SubSearchRequest{subReq},
|
||||
}
|
||||
err = exec.ProcessSearch(req)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// Determine whether the column corresponding to outputIDs contains functions, except bm25 function,
|
||||
// if outputIDs is empty, check all cols
|
||||
func HasNonBM25Functions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool {
|
||||
for _, fSchema := range functions {
|
||||
switch fSchema.GetType() {
|
||||
case schemapb.FunctionType_BM25:
|
||||
case schemapb.FunctionType_Unknown:
|
||||
default:
|
||||
if len(outputIDs) == 0 {
|
||||
return true
|
||||
} else {
|
||||
for _, id := range outputIDs {
|
||||
for _, fOutputID := range fSchema.GetOutputFieldIds() {
|
||||
if fOutputID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/ali"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/openai"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
|
||||
)
|
||||
|
||||
func mockEmbedding(texts []string, dim int) [][]float32 {
|
||||
embeddings := make([][]float32, 0)
|
||||
for i := 0; i < len(texts); i++ {
|
||||
f := float32(i)
|
||||
emb := make([]float32, 0)
|
||||
for j := 0; j < dim; j++ {
|
||||
emb = append(emb, f+float32(j)*0.1)
|
||||
}
|
||||
embeddings = append(embeddings, emb)
|
||||
}
|
||||
return embeddings
|
||||
}
|
||||
|
||||
func CreateErrorEmbeddingServer() *httptest.Server {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
return ts
|
||||
}
|
||||
|
||||
func CreateOpenAIEmbeddingServer() *httptest.Server {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req openai.EmbeddingRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
json.Unmarshal(body, &req)
|
||||
embs := mockEmbedding(req.Input, req.Dimensions)
|
||||
var res openai.EmbeddingResponse
|
||||
res.Object = "list"
|
||||
res.Model = "text-embedding-3-small"
|
||||
for i := 0; i < len(req.Input); i++ {
|
||||
res.Data = append(res.Data, openai.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Embedding: embs[i],
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: 1,
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
return ts
|
||||
}
|
||||
|
||||
func CreateAliEmbeddingServer() *httptest.Server {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req ali.EmbeddingRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
json.Unmarshal(body, &req)
|
||||
embs := mockEmbedding(req.Input.Texts, req.Parameters.Dimension)
|
||||
var res ali.EmbeddingResponse
|
||||
for i := 0; i < len(req.Input.Texts); i++ {
|
||||
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
|
||||
Embedding: embs[i],
|
||||
TextIndex: i,
|
||||
})
|
||||
}
|
||||
|
||||
res.Usage = ali.Usage{
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
return ts
|
||||
}
|
||||
|
||||
func CreateVoyageAIEmbeddingServer() *httptest.Server {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req voyageai.EmbeddingRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
json.Unmarshal(body, &req)
|
||||
embs := mockEmbedding(req.Input, int(req.OutputDimension))
|
||||
var res voyageai.EmbeddingResponse
|
||||
for i := 0; i < len(req.Input); i++ {
|
||||
res.Data = append(res.Data, voyageai.EmbeddingData{
|
||||
Object: "list",
|
||||
Embedding: embs[i],
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
res.Usage = voyageai.Usage{
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
return ts
|
||||
}
|
||||
|
||||
func CreateVertexAIEmbeddingServer() *httptest.Server {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req vertexai.EmbeddingRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
json.Unmarshal(body, &req)
|
||||
var texts []string
|
||||
for _, item := range req.Instances {
|
||||
texts = append(texts, item.Content)
|
||||
}
|
||||
embs := mockEmbedding(texts, int(req.Parameters.OutputDimensionality))
|
||||
var res vertexai.EmbeddingResponse
|
||||
for i := 0; i < len(req.Instances); i++ {
|
||||
res.Predictions = append(res.Predictions, vertexai.Prediction{
|
||||
Embeddings: vertexai.Embeddings{
|
||||
Statistics: vertexai.Statistics{
|
||||
Truncated: false,
|
||||
TokenCount: 10,
|
||||
},
|
||||
Values: embs[i],
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
res.Metadata = vertexai.Metadata{
|
||||
BillableCharacterCount: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
return ts
|
||||
}
|
||||
|
||||
type MockBedrockClient struct {
|
||||
dim int
|
||||
}
|
||||
|
||||
func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) {
|
||||
var req BedRockRequest
|
||||
json.Unmarshal(params.Body, &req)
|
||||
embs := mockEmbedding([]string{req.InputText}, c.dim)
|
||||
|
||||
var resp BedRockResponse
|
||||
resp.Embedding = embs[0]
|
||||
resp.InputTextTokenCount = 2
|
||||
body, _ := json.Marshal(resp)
|
||||
return &bedrockruntime.InvokeModelOutput{Body: body}, nil
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/utils"
|
||||
)
|
||||
|
||||
type Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
}
|
||||
|
||||
type Parameters struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
Dimension int `json:"dimension,omitempty"`
|
||||
OutputType string `json:"output_type,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
// ID of the model to use.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Input text to embed, encoded as a string.
|
||||
Input Input `json:"input"`
|
||||
|
||||
Parameters Parameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
// The total number of tokens used by the request.
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type SparseEmbedding struct {
|
||||
Index int `json:"index"`
|
||||
Value float32 `json:"value"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type Embeddings struct {
|
||||
TextIndex int `json:"text_index"`
|
||||
Embedding []float32 `json:"embedding,omitempty"`
|
||||
SparseEmbedding []SparseEmbedding `json:"sparse_embedding,omitempty"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Embeddings []Embeddings `json:"embeddings"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type ByIndex struct {
|
||||
resp *EmbeddingResponse
|
||||
}
|
||||
|
||||
func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) }
|
||||
func (eb *ByIndex) Swap(i, j int) {
|
||||
eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i]
|
||||
}
|
||||
|
||||
func (eb *ByIndex) Less(i, j int) bool {
|
||||
return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex
|
||||
}
|
||||
|
||||
type ErrorInfo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliDashScopeEmbedding struct {
|
||||
apiKey string
|
||||
url string
|
||||
}
|
||||
|
||||
func NewAliDashScopeEmbeddingClient(apiKey string, url string) *AliDashScopeEmbedding {
|
||||
return &AliDashScopeEmbedding{
|
||||
apiKey: apiKey,
|
||||
url: url,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AliDashScopeEmbedding) Check() error {
|
||||
if c.apiKey == "" {
|
||||
return fmt.Errorf("api key is empty")
|
||||
}
|
||||
|
||||
if c.url == "" {
|
||||
return fmt.Errorf("url is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
|
||||
var r EmbeddingRequest
|
||||
r.Model = modelName
|
||||
r.Input = Input{texts}
|
||||
r.Parameters.Dimension = dim
|
||||
r.Parameters.TextType = textType
|
||||
r.Parameters.OutputType = outputType
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeoutSec <= 0 {
|
||||
timeoutSec = utils.DefaultTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var res EmbeddingResponse
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Sort(&ByIndex{&res})
|
||||
return &res, err
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEmbeddingClientCheck(t *testing.T) {
|
||||
{
|
||||
c := NewAliDashScopeEmbeddingClient("", "mock_uri")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewAliDashScopeEmbeddingClient("mock_key", "")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewAliDashScopeEmbeddingClient("mock_key", "mock_uri")
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingOK(t *testing.T) {
|
||||
var res EmbeddingResponse
|
||||
repStr := `{
|
||||
"output": {
|
||||
"embeddings": [
|
||||
{
|
||||
"text_index": 1,
|
||||
"embedding": [0.1]
|
||||
},
|
||||
{
|
||||
"text_index": 0,
|
||||
"embedding": [0.0]
|
||||
},
|
||||
{
|
||||
"text_index": 2,
|
||||
"embedding": [0.2]
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": {
|
||||
"total_tokens": 100
|
||||
},
|
||||
"request_id": "0000000000000"
|
||||
}`
|
||||
err := json.Unmarshal([]byte(repStr), &res)
|
||||
assert.NoError(t, err)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewAliDashScopeEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
ret, err := c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0)
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Output.Embeddings[0].TextIndex, 0)
|
||||
assert.Equal(t, ret.Output.Embeddings[1].TextIndex, 1)
|
||||
assert.Equal(t, ret.Output.Embeddings[2].TextIndex, 2)
|
||||
assert.Equal(t, ret.Output.Embeddings[0].Embedding, []float32{0.0})
|
||||
assert.Equal(t, ret.Output.Embeddings[1].Embedding, []float32{0.1})
|
||||
assert.Equal(t, ret.Output.Embeddings[2].Embedding, []float32{0.2})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingFailed(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewAliDashScopeEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0)
|
||||
assert.True(t, err != nil)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,225 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/utils"
|
||||
)
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
// ID of the model to use.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Input text to embed, encoded as a string.
|
||||
Input []string `json:"input"`
|
||||
|
||||
// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
|
||||
User string `json:"user,omitempty"`
|
||||
|
||||
// The format to return the embeddings in. Can be either float or base64.
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
|
||||
// The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
// The number of tokens used by the prompt.
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
|
||||
// The total number of tokens used by the request.
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
// The object type, which is always "embedding".
|
||||
Object string `json:"object"`
|
||||
|
||||
// The embedding vector, which is a list of floats.
|
||||
Embedding []float32 `json:"embedding"`
|
||||
|
||||
// The index of the embedding in the list of embeddings.
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
// The object type, which is always "list".
|
||||
Object string `json:"object"`
|
||||
|
||||
// The list of embeddings generated by the model.
|
||||
Data []EmbeddingData `json:"data"`
|
||||
|
||||
// The name of the model used to generate the embedding.
|
||||
Model string `json:"model"`
|
||||
|
||||
// The usage information for the request.
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ByIndex struct {
|
||||
resp *EmbeddingResponse
|
||||
}
|
||||
|
||||
func (eb *ByIndex) Len() int { return len(eb.resp.Data) }
|
||||
func (eb *ByIndex) Swap(i, j int) {
|
||||
eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i]
|
||||
}
|
||||
func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index }
|
||||
|
||||
type ErrorInfo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param string `json:"param,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type EmbedddingError struct {
|
||||
Error ErrorInfo `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingInterface interface {
|
||||
Check() error
|
||||
Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error)
|
||||
}
|
||||
|
||||
type openAIBase struct {
|
||||
apiKey string
|
||||
url string
|
||||
}
|
||||
|
||||
func (c *openAIBase) Check() error {
|
||||
if c.apiKey == "" {
|
||||
return fmt.Errorf("api key is empty")
|
||||
}
|
||||
|
||||
if c.url == "" {
|
||||
return fmt.Errorf("url is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *openAIBase) genReq(modelName string, texts []string, dim int, user string) *EmbeddingRequest {
|
||||
var r EmbeddingRequest
|
||||
r.Model = modelName
|
||||
r.Input = texts
|
||||
r.EncodingFormat = "float"
|
||||
if user != "" {
|
||||
r.User = user
|
||||
}
|
||||
if dim != 0 {
|
||||
r.Dimensions = dim
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
|
||||
r := c.genReq(modelName, texts, dim, user)
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeoutSec <= 0 {
|
||||
timeoutSec = utils.DefaultTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var res EmbeddingResponse
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Sort(&ByIndex{&res})
|
||||
return &res, err
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingClient struct {
|
||||
openAIBase
|
||||
}
|
||||
|
||||
func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient {
|
||||
return &OpenAIEmbeddingClient{
|
||||
openAIBase{
|
||||
apiKey: apiKey,
|
||||
url: url,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec)
|
||||
}
|
||||
|
||||
type AzureOpenAIEmbeddingClient struct {
|
||||
openAIBase
|
||||
apiVersion string
|
||||
}
|
||||
|
||||
func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbeddingClient {
|
||||
return &AzureOpenAIEmbeddingClient{
|
||||
openAIBase: openAIBase{
|
||||
apiKey: apiKey,
|
||||
url: url,
|
||||
},
|
||||
apiVersion: "2024-06-01",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
|
||||
base, err := url.Parse(c.url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := fmt.Sprintf("/openai/deployments/%s/embeddings", modelName)
|
||||
base.Path = path
|
||||
params := url.Values{}
|
||||
params.Add("api-version", c.apiVersion)
|
||||
base.RawQuery = params.Encode()
|
||||
url := base.String()
|
||||
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"api-key": c.apiKey,
|
||||
}
|
||||
return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec)
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEmbeddingClientCheck(t *testing.T) {
|
||||
{
|
||||
c := NewOpenAIEmbeddingClient("", "mock_uri")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewOpenAIEmbeddingClient("mock_key", "")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewOpenAIEmbeddingClient("mock_key", "mock_uri")
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingOK(t *testing.T) {
|
||||
var res EmbeddingResponse
|
||||
res.Object = "list"
|
||||
res.Model = "text-embedding-3-small"
|
||||
res.Data = []EmbeddingData{
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
|
||||
Index: 1,
|
||||
},
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
|
||||
Index: 0,
|
||||
},
|
||||
}
|
||||
res.Usage = Usage{
|
||||
PromptTokens: 1,
|
||||
TotalTokens: 100,
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
if r.Header["Authorization"][0] != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
} else {
|
||||
if r.Header["Api-Key"][0] != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Data[0].Index, 0)
|
||||
assert.Equal(t, ret.Data[1].Index, 1)
|
||||
}
|
||||
{
|
||||
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Data[0].Index, 0)
|
||||
assert.Equal(t, ret.Data[1].Index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingRetry(t *testing.T) {
|
||||
var res EmbeddingResponse
|
||||
res.Object = "list"
|
||||
res.Model = "text-embedding-3-small"
|
||||
res.Data = []EmbeddingData{
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.1, 2.2, 3.2, 4.5},
|
||||
Index: 2,
|
||||
},
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
|
||||
Index: 0,
|
||||
},
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.1, 2.2, 3.2, 4.3},
|
||||
Index: 1,
|
||||
},
|
||||
}
|
||||
res.Usage = Usage{
|
||||
PromptTokens: 1,
|
||||
TotalTokens: 100,
|
||||
}
|
||||
|
||||
var count int32 = 0
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if atomic.LoadInt32(&count) < 2 {
|
||||
atomic.AddInt32(&count, 1)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Usage, res.Usage)
|
||||
assert.Equal(t, ret.Object, res.Object)
|
||||
assert.Equal(t, ret.Model, res.Model)
|
||||
assert.Equal(t, ret.Data[0], res.Data[1])
|
||||
assert.Equal(t, ret.Data[1], res.Data[2])
|
||||
assert.Equal(t, ret.Data[2], res.Data[0])
|
||||
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
|
||||
}
|
||||
{
|
||||
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Usage, res.Usage)
|
||||
assert.Equal(t, ret.Object, res.Object)
|
||||
assert.Equal(t, ret.Model, res.Model)
|
||||
assert.Equal(t, ret.Data[0], res.Data[1])
|
||||
assert.Equal(t, ret.Data[1], res.Data[2])
|
||||
assert.Equal(t, ret.Data[2], res.Data[0])
|
||||
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingFailed(t *testing.T) {
|
||||
var count int32 = 0
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&count, 1)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
atomic.StoreInt32(&count, 0)
|
||||
c := NewOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
|
||||
assert.True(t, err != nil)
|
||||
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
|
||||
}
|
||||
{
|
||||
atomic.StoreInt32(&count, 0)
|
||||
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
|
||||
assert.True(t, err != nil)
|
||||
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
var st int32 = 0
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(3 * time.Second)
|
||||
atomic.AddInt32(&st, 1)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
atomic.StoreInt32(&st, 0)
|
||||
c := NewOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1)
|
||||
assert.True(t, err != nil)
|
||||
assert.Equal(t, atomic.LoadInt32(&st), int32(0))
|
||||
time.Sleep(3 * time.Second)
|
||||
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
|
||||
}
|
||||
|
||||
{
|
||||
atomic.StoreInt32(&st, 0)
|
||||
c := NewAzureOpenAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1)
|
||||
assert.True(t, err != nil)
|
||||
assert.Equal(t, atomic.LoadInt32(&st), int32(0))
|
||||
time.Sleep(3 * time.Second)
|
||||
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const DefaultTimeout int64 = 30
|
||||
|
||||
func send(req *http.Request) ([]byte, error) {
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf(string(body))
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func RetrySend(req *http.Request, maxRetries int) ([]byte, error) {
|
||||
var err error
|
||||
var res []byte
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
res, err = send(req)
|
||||
if err == nil {
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
|
@ -0,0 +1,163 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/utils"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
TaskType string `json:"task_type,omitempty"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Parameters struct {
|
||||
OutputDimensionality int64 `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Instances []Instance `json:"instances"`
|
||||
Parameters Parameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type Statistics struct {
|
||||
Truncated bool `json:"truncated"`
|
||||
TokenCount int `json:"token_count"`
|
||||
}
|
||||
|
||||
type Embeddings struct {
|
||||
Statistics Statistics `json:"statistics"`
|
||||
Values []float32 `json:"values"`
|
||||
}
|
||||
|
||||
type Prediction struct {
|
||||
Embeddings Embeddings `json:"embeddings"`
|
||||
}
|
||||
|
||||
type Metadata struct {
|
||||
BillableCharacterCount int `json:"billableCharacterCount"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Predictions []Prediction `json:"predictions"`
|
||||
Metadata Metadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type ErrorInfo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type VertexAIEmbedding struct {
|
||||
url string
|
||||
jsonKey []byte
|
||||
scopes string
|
||||
token string
|
||||
}
|
||||
|
||||
func NewVertexAIEmbedding(url string, jsonKey []byte, scopes string, token string) *VertexAIEmbedding {
|
||||
return &VertexAIEmbedding{
|
||||
url: url,
|
||||
jsonKey: jsonKey,
|
||||
scopes: scopes,
|
||||
token: token,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *VertexAIEmbedding) Check() error {
|
||||
if c.url == "" {
|
||||
return fmt.Errorf("VertexAI embedding url is empty")
|
||||
}
|
||||
if len(c.jsonKey) == 0 {
|
||||
return fmt.Errorf("jsonKey is empty")
|
||||
}
|
||||
if c.scopes == "" {
|
||||
return fmt.Errorf("Scopes param is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *VertexAIEmbedding) getAccessToken() (string, error) {
|
||||
ctx := context.Background()
|
||||
creds, err := google.CredentialsFromJSON(ctx, c.jsonKey, c.scopes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Failed to find credentials: %v", err)
|
||||
}
|
||||
token, err := creds.TokenSource.Token()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Failed to get token: %v", err)
|
||||
}
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec int64) (*EmbeddingResponse, error) {
|
||||
var r EmbeddingRequest
|
||||
for _, text := range texts {
|
||||
r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text})
|
||||
}
|
||||
if dim != 0 {
|
||||
r.Parameters.OutputDimensionality = dim
|
||||
}
|
||||
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeoutSec <= 0 {
|
||||
timeoutSec = utils.DefaultTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var token string
|
||||
if c.token != "" {
|
||||
token = c.token
|
||||
} else {
|
||||
token, err = c.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var res EmbeddingResponse
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, err
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEmbeddingClientCheck(t *testing.T) {
|
||||
mockJSONKey := []byte{1, 2, 3}
|
||||
{
|
||||
c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewVertexAIEmbedding("", mockJSONKey, "", "")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "")
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingOK(t *testing.T) {
|
||||
var res EmbeddingResponse
|
||||
repStr := `{"predictions": [{"embeddings": {"statistics": {"truncated": false, "token_count": 4}, "values": [-0.028420744463801384, 0.037183016538619995]}}, {"embeddings": {"statistics": {"truncated": false, "token_count": 8}, "values": [-0.04367655888199806, 0.03777721896767616, 0.0158217903226614]}}], "metadata": {"billableCharacterCount": 27}}`
|
||||
err := json.Unmarshal([]byte(repStr), &res)
|
||||
assert.NoError(t, err)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token")
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-005", []string{"sentence"}, 0, "query", 0)
|
||||
assert.True(t, err == nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingFailed(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token")
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", 0)
|
||||
assert.True(t, err != nil)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package voyageai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/utils"
|
||||
)
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
// ID of the model to use.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Input text to embed, encoded as a string.
|
||||
Input []string `json:"input"`
|
||||
|
||||
InputType string `json:"input_type,omitempty"`
|
||||
|
||||
Truncation bool `json:"truncation,omitempty"`
|
||||
|
||||
OutputDimension int64 `json:"output_dimension,omitempty"`
|
||||
|
||||
OutputDtype string `json:"output_dtype,omitempty"`
|
||||
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
// The total number of tokens used by the request.
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
|
||||
Embedding []float32 `json:"embedding"`
|
||||
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
|
||||
Data []EmbeddingData `json:"data"`
|
||||
|
||||
Model string `json:"model"`
|
||||
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ByIndex struct {
|
||||
resp *EmbeddingResponse
|
||||
}
|
||||
|
||||
func (eb *ByIndex) Len() int { return len(eb.resp.Data) }
|
||||
func (eb *ByIndex) Swap(i, j int) {
|
||||
eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i]
|
||||
}
|
||||
|
||||
func (eb *ByIndex) Less(i, j int) bool {
|
||||
return eb.resp.Data[i].Index < eb.resp.Data[j].Index
|
||||
}
|
||||
|
||||
type ErrorInfo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type VoyageAIEmbedding struct {
|
||||
apiKey string
|
||||
url string
|
||||
}
|
||||
|
||||
func NewVoyageAIEmbeddingClient(apiKey string, url string) *VoyageAIEmbedding {
|
||||
return &VoyageAIEmbedding{
|
||||
apiKey: apiKey,
|
||||
url: url,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *VoyageAIEmbedding) Check() error {
|
||||
if c.apiKey == "" {
|
||||
return fmt.Errorf("api key is empty")
|
||||
}
|
||||
|
||||
if c.url == "" {
|
||||
return fmt.Errorf("url is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
|
||||
var r EmbeddingRequest
|
||||
r.Model = modelName
|
||||
r.Input = texts
|
||||
r.InputType = textType
|
||||
r.OutputDtype = outputType
|
||||
if dim != 0 {
|
||||
r.OutputDimension = int64(dim)
|
||||
}
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeoutSec <= 0 {
|
||||
timeoutSec = utils.DefaultTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
body, err := utils.RetrySend(req, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var res EmbeddingResponse
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Sort(&ByIndex{&res})
|
||||
return &res, err
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package voyageai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEmbeddingClientCheck(t *testing.T) {
|
||||
{
|
||||
c := NewVoyageAIEmbeddingClient("", "mock_uri")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", "")
|
||||
err := c.Check()
|
||||
assert.True(t, err != nil)
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
{
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", "mock_uri")
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingOK(t *testing.T) {
|
||||
var res EmbeddingResponse
|
||||
repStr := `{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0.0,
|
||||
0.1
|
||||
],
|
||||
"index": 0
|
||||
},
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
2.0,
|
||||
2.1
|
||||
],
|
||||
"index": 2
|
||||
},
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
1.0,
|
||||
1.1
|
||||
],
|
||||
"index": 1
|
||||
}
|
||||
],
|
||||
"model": "voyage-large-2",
|
||||
"usage": {
|
||||
"total_tokens": 10
|
||||
}
|
||||
}`
|
||||
err := json.Unmarshal([]byte(repStr), &res)
|
||||
assert.NoError(t, err)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
ret, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, ret.Data[0].Index, 0)
|
||||
assert.Equal(t, ret.Data[1].Index, 1)
|
||||
assert.Equal(t, ret.Data[2].Index, 2)
|
||||
assert.Equal(t, ret.Data[0].Embedding, []float32{0.0, 0.1})
|
||||
assert.Equal(t, ret.Data[1].Embedding, []float32{1.0, 1.1})
|
||||
assert.Equal(t, ret.Data[2].Embedding, []float32{2.0, 2.1})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingFailed(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
url := ts.URL
|
||||
|
||||
{
|
||||
c := NewVoyageAIEmbeddingClient("mock_key", url)
|
||||
err := c.Check()
|
||||
assert.True(t, err == nil)
|
||||
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
|
||||
assert.True(t, err != nil)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,176 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/openai"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type OpenAIEmbeddingProvider struct {
|
||||
fieldDim int64
|
||||
|
||||
client openai.OpenAIEmbeddingInterface
|
||||
modelName string
|
||||
embedDimParam int64
|
||||
user string
|
||||
|
||||
maxBatch int
|
||||
timeoutSec int64
|
||||
}
|
||||
|
||||
func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) {
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv(openaiAKEnvStr)
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr)
|
||||
}
|
||||
|
||||
if url == "" {
|
||||
url = "https://api.openai.com/v1/embeddings"
|
||||
}
|
||||
|
||||
c := openai.NewOpenAIEmbeddingClient(apiKey, url)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) {
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv(azureOpenaiAKEnvStr)
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr)
|
||||
}
|
||||
|
||||
if url == "" {
|
||||
if resourceName := os.Getenv(azureOpenaiResourceName); resourceName != "" {
|
||||
url = fmt.Sprintf("https://%s.openai.azure.com", resourceName)
|
||||
}
|
||||
}
|
||||
if url == "" {
|
||||
return nil, fmt.Errorf("Must configure the %s environment variable in the Milvus service", azureOpenaiResourceName)
|
||||
}
|
||||
c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, isAzure bool) (*OpenAIEmbeddingProvider, error) {
|
||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName, user string
|
||||
var dim int64
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case dimParamKey:
|
||||
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case userParamKey:
|
||||
user = param.Value
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
var c openai.OpenAIEmbeddingInterface
|
||||
if !isAzure {
|
||||
if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
|
||||
modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
|
||||
}
|
||||
|
||||
c, err = createOpenAIEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
c, err = createAzureOpenAIEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
provider := OpenAIEmbeddingProvider{
|
||||
client: c,
|
||||
fieldDim: fieldDim,
|
||||
modelName: modelName,
|
||||
user: user,
|
||||
embedDimParam: dim,
|
||||
maxBatch: 128,
|
||||
timeoutSec: 30,
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) {
|
||||
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, false)
|
||||
}
|
||||
|
||||
func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) {
|
||||
return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, true)
|
||||
}
|
||||
|
||||
func (provider *OpenAIEmbeddingProvider) MaxBatch() int {
|
||||
return 5 * provider.maxBatch
|
||||
}
|
||||
|
||||
func (provider *OpenAIEmbeddingProvider) FieldDim() int64 {
|
||||
return provider.fieldDim
|
||||
}
|
||||
|
||||
func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, _ TextEmbeddingMode) ([][]float32, error) {
|
||||
numRows := len(texts)
|
||||
data := make([][]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i += provider.maxBatch {
|
||||
end := i + provider.maxBatch
|
||||
if end > numRows {
|
||||
end = numRows
|
||||
}
|
||||
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), provider.user, provider.timeoutSec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if end-i != len(resp.Data) {
|
||||
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data))
|
||||
}
|
||||
for _, item := range resp.Data {
|
||||
if len(item.Embedding) != int(provider.fieldDim) {
|
||||
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
|
||||
provider.fieldDim, len(item.Embedding))
|
||||
}
|
||||
data = append(data, item.Embedding)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/openai"
|
||||
)
|
||||
|
||||
func TestOpenAITextEmbeddingProvider(t *testing.T) {
|
||||
suite.Run(t, new(OpenAITextEmbeddingProviderSuite))
|
||||
}
|
||||
|
||||
type OpenAITextEmbeddingProviderSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *OpenAITextEmbeddingProviderSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.providers = []string{openAIProvider, azureOpenAIProvider}
|
||||
}
|
||||
|
||||
func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
},
|
||||
}
|
||||
switch providerName {
|
||||
case openAIProvider:
|
||||
return NewOpenAIEmbeddingProvider(schema, functionSchema)
|
||||
case azureOpenAIProvider:
|
||||
return NewAzureOpenAIEmbeddingProvider(schema, functionSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknow provider")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() {
|
||||
ts := CreateOpenAIEmbeddingServer()
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
s.NoError(err)
|
||||
{
|
||||
data := []string{"sentence"}
|
||||
ret, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(4, len(ret[0]))
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
|
||||
}
|
||||
{
|
||||
data := []string{"sentence 1", "sentence 2", "sentence 3"}
|
||||
ret, _ := provder.CallEmbedding(data, SearchMode)
|
||||
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res openai.EmbeddingResponse
|
||||
res.Object = "list"
|
||||
res.Model = "text-embedding-3-small"
|
||||
res.Data = append(res.Data, openai.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
Index: 0,
|
||||
})
|
||||
|
||||
res.Data = append(res.Data, openai.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.0, 1.0},
|
||||
Index: 1,
|
||||
})
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: 1,
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res openai.EmbeddingResponse
|
||||
res.Object = "list"
|
||||
res.Model = "text-embedding-3-small"
|
||||
res.Data = append(res.Data, openai.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
Index: 0,
|
||||
})
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: 1,
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence2"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAITextEmbeddingProviderSuite) TestCreateOpenAIEmbeddingClient() {
|
||||
_, err := createOpenAIEmbeddingClient("", "")
|
||||
s.Error(err)
|
||||
|
||||
os.Setenv(openaiAKEnvStr, "mockKey")
|
||||
defer os.Unsetenv(openaiAKEnvStr)
|
||||
|
||||
_, err = createOpenAIEmbeddingClient("", "")
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *OpenAITextEmbeddingProviderSuite) TestCreateAzureOpenAIEmbeddingClient() {
|
||||
_, err := createAzureOpenAIEmbeddingClient("", "")
|
||||
s.Error(err)
|
||||
|
||||
os.Setenv(azureOpenaiAKEnvStr, "mockKey")
|
||||
defer os.Unsetenv(azureOpenaiAKEnvStr)
|
||||
|
||||
_, err = createAzureOpenAIEmbeddingClient("", "")
|
||||
s.Error(err)
|
||||
|
||||
os.Setenv(azureOpenaiResourceName, "mockResource")
|
||||
defer os.Unsetenv(azureOpenaiResourceName)
|
||||
|
||||
_, err = createAzureOpenAIEmbeddingClient("", "")
|
||||
s.NoError(err)
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
)
|
||||
|
||||
const (
|
||||
Provider string = "provider"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIProvider string = "openai"
|
||||
azureOpenAIProvider string = "azure_openai"
|
||||
aliDashScopeProvider string = "dashscope"
|
||||
bedrockProvider string = "bedrock"
|
||||
vertexAIProvider string = "vertexai"
|
||||
voyageAIProvider string = "voyageai"
|
||||
)
|
||||
|
||||
// Text embedding for retrieval task
|
||||
type textEmbeddingProvider interface {
|
||||
MaxBatch() int
|
||||
CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error)
|
||||
FieldDim() int64
|
||||
}
|
||||
|
||||
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case Provider:
|
||||
return strings.ToLower(param.Value), nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
|
||||
}
|
||||
|
||||
type TextEmbeddingFunction struct {
|
||||
FunctionBase
|
||||
|
||||
embProvider textEmbeddingProvider
|
||||
}
|
||||
|
||||
func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmbeddingFunction, error) {
|
||||
if len(functionSchema.GetOutputFieldNames()) != 1 {
|
||||
return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames()))
|
||||
}
|
||||
|
||||
base, err := NewFunctionBase(coll, functionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if base.outputFields[0].DataType != schemapb.DataType_FloatVector {
|
||||
return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s], got [%s]",
|
||||
schemapb.DataType_name[int32(schemapb.DataType_FloatVector)],
|
||||
schemapb.DataType_name[int32(base.outputFields[0].DataType)])
|
||||
}
|
||||
|
||||
provider, err := getProvider(functionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch provider {
|
||||
case openAIProvider:
|
||||
embP, err := NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TextEmbeddingFunction{
|
||||
FunctionBase: *base,
|
||||
embProvider: embP,
|
||||
}, nil
|
||||
case azureOpenAIProvider:
|
||||
embP, err := NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TextEmbeddingFunction{
|
||||
FunctionBase: *base,
|
||||
embProvider: embP,
|
||||
}, nil
|
||||
case bedrockProvider:
|
||||
embP, err := NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TextEmbeddingFunction{
|
||||
FunctionBase: *base,
|
||||
embProvider: embP,
|
||||
}, nil
|
||||
case aliDashScopeProvider:
|
||||
embP, err := NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TextEmbeddingFunction{
|
||||
FunctionBase: *base,
|
||||
embProvider: embP,
|
||||
}, nil
|
||||
case vertexAIProvider:
|
||||
embP, err := NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TextEmbeddingFunction{
|
||||
FunctionBase: *base,
|
||||
embProvider: embP,
|
||||
}, nil
|
||||
case voyageAIProvider:
|
||||
embP, err := NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TextEmbeddingFunction{
|
||||
FunctionBase: *base,
|
||||
embProvider: embP,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider)
|
||||
}
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) MaxBatch() int {
|
||||
return runner.embProvider.MaxBatch()
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
|
||||
if len(inputs) != 1 {
|
||||
return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs))
|
||||
}
|
||||
|
||||
if inputs[0].Type != schemapb.DataType_VarChar {
|
||||
return nil, fmt.Errorf("Text embedding only supports varchar field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)])
|
||||
}
|
||||
|
||||
texts := inputs[0].GetScalars().GetStringData().GetData()
|
||||
if texts == nil {
|
||||
return nil, fmt.Errorf("Input texts is empty")
|
||||
}
|
||||
numRows := len(texts)
|
||||
if numRows > runner.MaxBatch() {
|
||||
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
||||
}
|
||||
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim()))
|
||||
for _, emb := range embds {
|
||||
data = append(data, emb...)
|
||||
}
|
||||
|
||||
var outputField schemapb.FieldData
|
||||
outputField.FieldId = runner.outputFields[0].FieldID
|
||||
outputField.FieldName = runner.outputFields[0].Name
|
||||
outputField.Type = runner.outputFields[0].DataType
|
||||
outputField.IsDynamic = runner.outputFields[0].IsDynamic
|
||||
outputField.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
Dim: runner.embProvider.FieldDim(),
|
||||
},
|
||||
}
|
||||
return []*schemapb.FieldData{&outputField}, nil
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
|
||||
texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally
|
||||
numRows := len(texts)
|
||||
if numRows > runner.MaxBatch() {
|
||||
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
||||
}
|
||||
embds, err := runner.embProvider.CallEmbedding(texts, SearchMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return funcutil.Float32VectorsToPlaceholderGroup(embds), nil
|
||||
}
|
||||
|
||||
func (runner *TextEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) {
|
||||
if len(inputs) != 1 {
|
||||
return nil, fmt.Errorf("TextEmbedding function only receives one input, bug got [%d]", len(inputs))
|
||||
}
|
||||
|
||||
if inputs[0].GetDataType() != schemapb.DataType_VarChar {
|
||||
return nil, fmt.Errorf(" only supports varchar field, the input is not varchar")
|
||||
}
|
||||
|
||||
texts, ok := inputs[0].GetDataRows().([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Input texts is empty")
|
||||
}
|
||||
|
||||
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim()))
|
||||
for _, emb := range embds {
|
||||
data = append(data, emb...)
|
||||
}
|
||||
|
||||
field := &storage.FloatVectorFieldData{
|
||||
Data: data,
|
||||
Dim: int(runner.embProvider.FieldDim()),
|
||||
}
|
||||
return map[storage.FieldID]storage.FieldData{
|
||||
runner.outputFields[0].FieldID: field,
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,633 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/testutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
)
|
||||
|
||||
func TestTextEmbeddingFunction(t *testing.T) {
|
||||
suite.Run(t, new(TextEmbeddingFunctionSuite))
|
||||
}
|
||||
|
||||
type TextEmbeddingFunctionSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createData(texts []string) []*schemapb.FieldData {
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: texts,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestInvalidProvider() {
|
||||
fSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
providerName, err := getProvider(fSchema)
|
||||
s.Equal(providerName, openAIProvider)
|
||||
s.NoError(err)
|
||||
|
||||
fSchema.Params = []*commonpb.KeyValuePair{}
|
||||
providerName, err = getProvider(fSchema)
|
||||
s.Equal(providerName, "")
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestProcessInsert() {
|
||||
ts := CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
{
|
||||
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
{
|
||||
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
|
||||
ret, _ := runner.ProcessInsert(data)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
}
|
||||
{
|
||||
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: azureOpenAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
{
|
||||
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
|
||||
ret, _ := runner.ProcessInsert(data)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
|
||||
ts := CreateAliEmbeddingServer()
|
||||
defer ts.Close()
|
||||
|
||||
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: aliDashScopeProvider},
|
||||
{Key: modelNameParamKey, Value: TextEmbeddingV3},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
{
|
||||
data := createData([]string{"sentence"})
|
||||
ret, err2 := runner.ProcessInsert(data)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(int64(4), ret[0].GetVectors().Dim)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
{
|
||||
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
|
||||
ret, _ := runner.ProcessInsert(data)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
|
||||
// multi-input
|
||||
{
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// wrong input data type
|
||||
{
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int32,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
s.Error(err)
|
||||
}
|
||||
// empty input
|
||||
{
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
s.Error(err)
|
||||
}
|
||||
// large input data
|
||||
{
|
||||
data := []*schemapb.FieldData{}
|
||||
f := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strings.Split(strings.Repeat("Element,", 1000), ","),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data = append(data, &f)
|
||||
_, err := runner.ProcessInsert(data)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() {
|
||||
// outputfield datatype mismatch
|
||||
{
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// outputfield number mismatc
|
||||
{
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector", "vector2"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102, 103},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// outputfield miss
|
||||
{
|
||||
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector2"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{103},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// error model name
|
||||
{
|
||||
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-004"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// no openai api key
|
||||
{
|
||||
_, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-003"},
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestNewTextEmbeddings() {
|
||||
{
|
||||
fSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: bedrockProvider},
|
||||
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
|
||||
{Key: awsAKIdParamKey, Value: "mock"},
|
||||
{Key: awsSAKParamKey, Value: "mock"},
|
||||
{Key: regionParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.NoError(err)
|
||||
fSchema.Params = []*commonpb.KeyValuePair{}
|
||||
_, err = NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
{
|
||||
fSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: aliDashScopeProvider},
|
||||
{Key: modelNameParamKey, Value: TextEmbeddingV1},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.NoError(err)
|
||||
fSchema.Params = []*commonpb.KeyValuePair{}
|
||||
_, err = NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
{
|
||||
fSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: voyageAIProvider},
|
||||
{Key: modelNameParamKey, Value: voyage3},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.NoError(err)
|
||||
fSchema.Params = []*commonpb.KeyValuePair{}
|
||||
_, err = NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Invalid params
|
||||
{
|
||||
fSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{},
|
||||
}
|
||||
|
||||
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
{
|
||||
fSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: "unkownProvider"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewTextEmbeddingFunction(s.schema, fSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestProcessSearch() {
|
||||
ts := CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// Large inputs
|
||||
{
|
||||
f := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strings.Split(strings.Repeat("Element,", 1000), ","),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
|
||||
s.NoError(err)
|
||||
placeholderGroup := commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
|
||||
_, err = runner.ProcessSearch(&placeholderGroup)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Normal inputs
|
||||
{
|
||||
f := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldId: 101,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: strings.Split(strings.Repeat("Element,", 100), ","),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(f)
|
||||
s.NoError(err)
|
||||
placeholderGroup := commonpb.PlaceholderGroup{}
|
||||
proto.Unmarshal(placeholderGroupBytes, &placeholderGroup)
|
||||
_, err = runner.ProcessSearch(&placeholderGroup)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TextEmbeddingFunctionSuite) TestProcessBulkInsert() {
|
||||
ts := CreateOpenAIEmbeddingServer()
|
||||
defer ts.Close()
|
||||
runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_TextEmbedding,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Provider, Value: openAIProvider},
|
||||
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: ts.URL},
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
data, err := testutil.CreateInsertData(s.schema, 100)
|
||||
s.NoError(err)
|
||||
{
|
||||
input := []storage.FieldData{data.Data[101]}
|
||||
_, err := runner.ProcessBulkInsert(input)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Multi-input
|
||||
{
|
||||
input := []storage.FieldData{data.Data[101], data.Data[101]}
|
||||
_, err := runner.ProcessBulkInsert(input)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Error input type
|
||||
{
|
||||
input := []storage.FieldData{data.Data[102]}
|
||||
_, err := runner.ProcessBulkInsert(input)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type vertexAIJsonKey struct {
|
||||
jsonKey []byte
|
||||
once sync.Once
|
||||
initErr error
|
||||
}
|
||||
|
||||
var vtxKey vertexAIJsonKey
|
||||
|
||||
func getVertexAIJsonKey() ([]byte, error) {
|
||||
vtxKey.once.Do(func() {
|
||||
jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
|
||||
jsonKey, err := os.ReadFile(jsonKeyPath)
|
||||
if err != nil {
|
||||
vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err)
|
||||
return
|
||||
}
|
||||
vtxKey.jsonKey = jsonKey
|
||||
})
|
||||
return vtxKey.jsonKey, vtxKey.initErr
|
||||
}
|
||||
|
||||
const (
|
||||
vertexAIDocRetrival string = "DOC_RETRIEVAL"
|
||||
vertexAICodeRetrival string = "CODE_RETRIEVAL"
|
||||
vertexAISTS string = "STS"
|
||||
)
|
||||
|
||||
func checkTask(modelName string, task string) error {
|
||||
if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS {
|
||||
return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS)
|
||||
}
|
||||
if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival {
|
||||
return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type VertexAIEmbeddingProvider struct {
|
||||
fieldDim int64
|
||||
|
||||
client *vertexai.VertexAIEmbedding
|
||||
modelName string
|
||||
embedDimParam int64
|
||||
task string
|
||||
|
||||
maxBatch int
|
||||
timeoutSec int64
|
||||
}
|
||||
|
||||
func createVertexAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) {
|
||||
jsonKey, err := getVertexAIJsonKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := vertexai.NewVertexAIEmbedding(url, jsonKey, "https://www.googleapis.com/auth/cloud-platform", "")
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewVertexAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertexAIEmbeddingProvider, error) {
|
||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var location, projectID, task, modelName string
|
||||
var dim int64
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case dimParamKey:
|
||||
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case locationParamKey:
|
||||
location = param.Value
|
||||
case projectIDParamKey:
|
||||
projectID = param.Value
|
||||
case taskTypeParamKey:
|
||||
task = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if task == "" {
|
||||
task = vertexAIDocRetrival
|
||||
}
|
||||
if err := checkTask(modelName, task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if location == "" {
|
||||
location = "us-central1"
|
||||
}
|
||||
|
||||
if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]",
|
||||
modelName, textEmbedding005, textMultilingualEmbedding002)
|
||||
}
|
||||
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
|
||||
var client *vertexai.VertexAIEmbedding
|
||||
if c == nil {
|
||||
client, err = createVertexAIEmbeddingClient(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
client = c
|
||||
}
|
||||
|
||||
provider := VertexAIEmbeddingProvider{
|
||||
fieldDim: fieldDim,
|
||||
client: client,
|
||||
modelName: modelName,
|
||||
embedDimParam: dim,
|
||||
task: task,
|
||||
maxBatch: 128,
|
||||
timeoutSec: 30,
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
func (provider *VertexAIEmbeddingProvider) MaxBatch() int {
|
||||
return 5 * provider.maxBatch
|
||||
}
|
||||
|
||||
func (provider *VertexAIEmbeddingProvider) FieldDim() int64 {
|
||||
return provider.fieldDim
|
||||
}
|
||||
|
||||
func (provider *VertexAIEmbeddingProvider) getTaskType(mode TextEmbeddingMode) string {
|
||||
if mode == SearchMode {
|
||||
switch provider.task {
|
||||
case vertexAIDocRetrival:
|
||||
return "RETRIEVAL_QUERY"
|
||||
case vertexAICodeRetrival:
|
||||
return "CODE_RETRIEVAL_QUERY"
|
||||
case vertexAISTS:
|
||||
return "SEMANTIC_SIMILARITY"
|
||||
}
|
||||
} else {
|
||||
switch provider.task {
|
||||
case vertexAIDocRetrival:
|
||||
return "RETRIEVAL_DOCUMENT"
|
||||
case vertexAICodeRetrival: // When inserting, the model does not distinguish between doc and code
|
||||
return "RETRIEVAL_DOCUMENT"
|
||||
case vertexAISTS:
|
||||
return "SEMANTIC_SIMILARITY"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (provider *VertexAIEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
|
||||
numRows := len(texts)
|
||||
taskType := provider.getTaskType(mode)
|
||||
data := make([][]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i += provider.maxBatch {
|
||||
end := i + provider.maxBatch
|
||||
if end > numRows {
|
||||
end = numRows
|
||||
}
|
||||
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, provider.timeoutSec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if end-i != len(resp.Predictions) {
|
||||
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Predictions))
|
||||
}
|
||||
for _, item := range resp.Predictions {
|
||||
if len(item.Embeddings.Values) != int(provider.fieldDim) {
|
||||
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
|
||||
provider.fieldDim, len(item.Embeddings.Values))
|
||||
}
|
||||
data = append(data, item.Embeddings.Values)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
|
@ -0,0 +1,276 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
|
||||
)
|
||||
|
||||
func TestVertexAITextEmbeddingProvider(t *testing.T) {
|
||||
suite.Run(t, new(VertexAITextEmbeddingProviderSuite))
|
||||
}
|
||||
|
||||
type VertexAITextEmbeddingProviderSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createVertexAIProvider(url string, schema *schemapb.FieldSchema) (textEmbeddingProvider, error) {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: textEmbedding005},
|
||||
{Key: locationParamKey, Value: "mock_local"},
|
||||
{Key: projectIDParamKey, Value: "mock_id"},
|
||||
{Key: taskTypeParamKey, Value: vertexAICodeRetrival},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token")
|
||||
return NewVertexAIEmbeddingProvider(schema, functionSchema, mockClient)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestEmbedding() {
|
||||
ts := CreateVertexAIEmbeddingServer()
|
||||
|
||||
defer ts.Close()
|
||||
provder, err := createVertexAIProvider(ts.URL, s.schema.Fields[2])
|
||||
s.NoError(err)
|
||||
{
|
||||
data := []string{"sentence"}
|
||||
ret, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(4, len(ret[0]))
|
||||
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
|
||||
}
|
||||
{
|
||||
data := []string{"sentence 1", "sentence 2", "sentence 3"}
|
||||
ret, _ := provder.CallEmbedding(data, SearchMode)
|
||||
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res vertexai.EmbeddingResponse
|
||||
res.Predictions = append(res.Predictions, vertexai.Prediction{
|
||||
Embeddings: vertexai.Embeddings{
|
||||
Statistics: vertexai.Statistics{
|
||||
Truncated: false,
|
||||
TokenCount: 10,
|
||||
},
|
||||
Values: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
},
|
||||
})
|
||||
res.Predictions = append(res.Predictions, vertexai.Prediction{
|
||||
Embeddings: vertexai.Embeddings{
|
||||
Statistics: vertexai.Statistics{
|
||||
Truncated: false,
|
||||
TokenCount: 10,
|
||||
},
|
||||
Values: []float32{1.0, 1.0},
|
||||
},
|
||||
})
|
||||
|
||||
res.Metadata = vertexai.Metadata{
|
||||
BillableCharacterCount: 100,
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
provder, err := createVertexAIProvider(ts.URL, s.schema.Fields[2])
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res vertexai.EmbeddingResponse
|
||||
res.Predictions = append(res.Predictions, vertexai.Prediction{
|
||||
Embeddings: vertexai.Embeddings{
|
||||
Statistics: vertexai.Statistics{
|
||||
Truncated: false,
|
||||
TokenCount: 10,
|
||||
},
|
||||
Values: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
},
|
||||
})
|
||||
res.Metadata = vertexai.Metadata{
|
||||
BillableCharacterCount: 100,
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
provder, err := createVertexAIProvider(ts.URL, s.schema.Fields[2])
|
||||
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence2"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestCheckVertexAITask() {
|
||||
err := checkTask(textMultilingualEmbedding002, "UnkownTask")
|
||||
s.Error(err)
|
||||
|
||||
// textMultilingualEmbedding002 not support vertexAICodeRetrival task
|
||||
err = checkTask(textMultilingualEmbedding002, vertexAICodeRetrival)
|
||||
s.Error(err)
|
||||
|
||||
err = checkTask(textEmbedding005, vertexAICodeRetrival)
|
||||
s.NoError(err)
|
||||
|
||||
err = checkTask(textMultilingualEmbedding002, vertexAISTS)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestGetVertexAIJsonKey() {
|
||||
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
||||
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
||||
_, err := getVertexAIJsonKey()
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestGetTaskType() {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: textEmbedding005},
|
||||
{Key: projectIDParamKey, Value: "mock_id"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
|
||||
|
||||
{
|
||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.NoError(err)
|
||||
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
|
||||
s.Equal(provider.getTaskType(SearchMode), "RETRIEVAL_QUERY")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAICodeRetrival})
|
||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.NoError(err)
|
||||
s.Equal(provider.getTaskType(InsertMode), "RETRIEVAL_DOCUMENT")
|
||||
s.Equal(provider.getTaskType(SearchMode), "CODE_RETRIEVAL_QUERY")
|
||||
}
|
||||
|
||||
{
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: vertexAISTS}
|
||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.NoError(err)
|
||||
s.Equal(provider.getTaskType(InsertMode), "SEMANTIC_SIMILARITY")
|
||||
s.Equal(provider.getTaskType(SearchMode), "SEMANTIC_SIMILARITY")
|
||||
}
|
||||
|
||||
// invalid task
|
||||
{
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: taskTypeParamKey, Value: "UnkownTask"}
|
||||
_, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestCreateVertexAIEmbeddingClient() {
|
||||
os.Setenv(vertexServiceAccountJSONEnv, "ErrorPath")
|
||||
defer os.Unsetenv(vertexServiceAccountJSONEnv)
|
||||
_, err := createVertexAIEmbeddingClient("https://mock_url.com")
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *VertexAITextEmbeddingProviderSuite) TestNewVertexAIEmbeddingProvider() {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: textEmbedding005},
|
||||
{Key: projectIDParamKey, Value: "mock_id"},
|
||||
{Key: dimParamKey, Value: "4"},
|
||||
},
|
||||
}
|
||||
mockClient := vertexai.NewVertexAIEmbedding("mock_url", []byte{1, 2, 3}, "mock scope", "mock token")
|
||||
provider, err := NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.NoError(err)
|
||||
s.True(provider.MaxBatch() > 0)
|
||||
s.Equal(provider.FieldDim(), int64(4))
|
||||
|
||||
// check model name
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
_, err = NewVertexAIEmbeddingProvider(s.schema.Fields[2], functionSchema, mockClient)
|
||||
s.Error(err)
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type VoyageAIEmbeddingProvider struct {
|
||||
fieldDim int64
|
||||
|
||||
client *voyageai.VoyageAIEmbedding
|
||||
modelName string
|
||||
embedDimParam int64
|
||||
|
||||
maxBatch int
|
||||
timeoutSec int64
|
||||
}
|
||||
|
||||
func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) {
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv(voyageAIAKEnvStr)
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr)
|
||||
}
|
||||
|
||||
if url == "" {
|
||||
url = "https://api.voyageai.com/v1/embeddings"
|
||||
}
|
||||
|
||||
c := voyageai.NewVoyageAIEmbeddingClient(apiKey, url)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*VoyageAIEmbeddingProvider, error) {
|
||||
fieldDim, err := typeutil.GetDim(fieldSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var apiKey, url, modelName string
|
||||
dim := int64(0)
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case modelNameParamKey:
|
||||
modelName = param.Value
|
||||
case dimParamKey:
|
||||
// Only voyage-3-large and voyage-code-3 support dim param: 1024 (default), 256, 512, 2048
|
||||
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case apiKeyParamKey:
|
||||
apiKey = param.Value
|
||||
case embeddingURLParamKey:
|
||||
url = param.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if modelName != voyage3Large && modelName != voyage3 && modelName != voyage3Lite && modelName != voyageCode3 && modelName != voyageFinance2 && modelName != voyageLaw2 && modelName != voyageCode2 {
|
||||
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
|
||||
modelName, voyage3Large, voyage3, voyage3Lite, voyageCode3, voyageFinance2, voyageLaw2, voyageCode2)
|
||||
}
|
||||
|
||||
if dim != 0 {
|
||||
if modelName != voyage3Large && modelName != voyageCode3 {
|
||||
return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports dim parameter, only [%s, %s] support it.", modelName, voyage3, voyageCode3)
|
||||
}
|
||||
if dim != 1024 && dim != 256 && dim != 512 && dim != 2048 {
|
||||
return nil, fmt.Errorf("VoyageAI text embedding model's dim only supports 2048, 1024 (default), 512, and 256.")
|
||||
}
|
||||
}
|
||||
c, err := createVoyageAIEmbeddingClient(apiKey, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider := VoyageAIEmbeddingProvider{
|
||||
client: c,
|
||||
fieldDim: fieldDim,
|
||||
modelName: modelName,
|
||||
embedDimParam: dim,
|
||||
maxBatch: 128,
|
||||
timeoutSec: 30,
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
func (provider *VoyageAIEmbeddingProvider) MaxBatch() int {
|
||||
return 5 * provider.maxBatch
|
||||
}
|
||||
|
||||
func (provider *VoyageAIEmbeddingProvider) FieldDim() int64 {
|
||||
return provider.fieldDim
|
||||
}
|
||||
|
||||
func (provider *VoyageAIEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
|
||||
numRows := len(texts)
|
||||
var textType string
|
||||
if mode == InsertMode {
|
||||
textType = "document"
|
||||
} else {
|
||||
textType = "query"
|
||||
}
|
||||
|
||||
data := make([][]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i += provider.maxBatch {
|
||||
end := i + provider.maxBatch
|
||||
if end > numRows {
|
||||
end = numRows
|
||||
}
|
||||
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, "float", provider.timeoutSec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if end-i != len(resp.Data) {
|
||||
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data))
|
||||
}
|
||||
for _, item := range resp.Data {
|
||||
if len(item.Embedding) != int(provider.fieldDim) {
|
||||
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
|
||||
provider.fieldDim, len(item.Embedding))
|
||||
}
|
||||
data = append(data, item.Embedding)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package function
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
|
||||
)
|
||||
|
||||
func TestVoyageAITextEmbeddingProvider(t *testing.T) {
|
||||
suite.Run(t, new(VoyageAITextEmbeddingProviderSuite))
|
||||
}
|
||||
|
||||
type VoyageAITextEmbeddingProviderSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
providers []string
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "1024"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.providers = []string{voyageAIProvider}
|
||||
}
|
||||
|
||||
func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: voyage3Large},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: url},
|
||||
{Key: dimParamKey, Value: "1024"},
|
||||
},
|
||||
}
|
||||
switch providerName {
|
||||
case voyageAIProvider:
|
||||
return NewVoyageAIEmbeddingProvider(schema, functionSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknow provider")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbedding() {
|
||||
ts := CreateVoyageAIEmbeddingServer()
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
s.NoError(err)
|
||||
{
|
||||
data := []string{"sentence"}
|
||||
ret, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.NoError(err2)
|
||||
s.Equal(1, len(ret))
|
||||
s.Equal(1024, len(ret[0]))
|
||||
}
|
||||
{
|
||||
data := []string{"sentence 1", "sentence 2", "sentence 3"}
|
||||
_, err := provder.CallEmbedding(data, SearchMode)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res voyageai.EmbeddingResponse
|
||||
res.Data = append(res.Data, voyageai.EmbeddingData{
|
||||
Object: "list",
|
||||
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
Index: 0,
|
||||
})
|
||||
|
||||
res.Data = append(res.Data, voyageai.EmbeddingData{
|
||||
Object: "list",
|
||||
Embedding: []float32{1.0, 1.0},
|
||||
Index: 1,
|
||||
})
|
||||
res.Usage = voyageai.Usage{
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
for _, providerName := range s.providers {
|
||||
provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], providerName)
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var res voyageai.EmbeddingResponse
|
||||
res.Data = append(res.Data, voyageai.EmbeddingData{
|
||||
Object: "list",
|
||||
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
|
||||
Index: 0,
|
||||
})
|
||||
res.Usage = voyageai.Usage{
|
||||
TotalTokens: 100,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
}))
|
||||
|
||||
defer ts.Close()
|
||||
for _, provderName := range s.providers {
|
||||
provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], provderName)
|
||||
|
||||
s.NoError(err)
|
||||
|
||||
// embedding dim not match
|
||||
data := []string{"sentence", "sentence2"}
|
||||
_, err2 := provder.CallEmbedding(data, InsertMode)
|
||||
s.Error(err2)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) TestCreateVoyageAIEmbeddingClient() {
|
||||
_, err := createVoyageAIEmbeddingClient("", "")
|
||||
s.Error(err)
|
||||
|
||||
os.Setenv(voyageAIAKEnvStr, "mockKey")
|
||||
defer os.Unsetenv(voyageAIAKEnvStr)
|
||||
_, err = createVoyageAIEmbeddingClient("", "")
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *VoyageAITextEmbeddingProviderSuite) TestNewVoyageAIEmbeddingProvider() {
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Unknown,
|
||||
InputFieldNames: []string{"text"},
|
||||
OutputFieldNames: []string{"vector"},
|
||||
InputFieldIds: []int64{101},
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: modelNameParamKey, Value: voyage3Large},
|
||||
{Key: apiKeyParamKey, Value: "mock"},
|
||||
{Key: embeddingURLParamKey, Value: "mock"},
|
||||
{Key: dimParamKey, Value: "1024"},
|
||||
},
|
||||
}
|
||||
provider, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.NoError(err)
|
||||
s.Equal(provider.FieldDim(), int64(1024))
|
||||
s.True(provider.MaxBatch() > 0)
|
||||
|
||||
// Invalid model
|
||||
{
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: modelNameParamKey, Value: "UnkownModel"}
|
||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Invalid dim
|
||||
{
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "9"}
|
||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Invalid dim type
|
||||
{
|
||||
functionSchema.Params[3] = &commonpb.KeyValuePair{Key: dimParamKey, Value: "Invalied"}
|
||||
_, err := NewVoyageAIEmbeddingProvider(s.schema.Fields[2], functionSchema)
|
||||
s.Error(err)
|
||||
}
|
||||
}
|
|
@ -25,6 +25,21 @@ func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte {
|
|||
return bytes
|
||||
}
|
||||
|
||||
func Float32VectorsToPlaceholderGroup(embs [][]float32) *commonpb.PlaceholderGroup {
|
||||
result := make([][]byte, 0, len(embs))
|
||||
for _, floatVector := range embs {
|
||||
result = append(result, floatVectorToByteVector(floatVector))
|
||||
}
|
||||
placeholderGroup := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: result,
|
||||
}},
|
||||
}
|
||||
return placeholderGroup
|
||||
}
|
||||
|
||||
func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
|
||||
placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue