mirror of https://github.com/milvus-io/milvus.git
371 lines
12 KiB
Go
371 lines
12 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package client
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"testing"
|
|
|
|
"github.com/samber/lo"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
)
|
|
|
|
type WriteSuite struct {
|
|
MockSuiteBase
|
|
|
|
schema *entity.Schema
|
|
schemaDyn *entity.Schema
|
|
}
|
|
|
|
func (s *WriteSuite) SetupSuite() {
|
|
s.MockSuiteBase.SetupSuite()
|
|
s.schema = entity.NewSchema().
|
|
WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
|
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
|
|
|
s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
|
|
WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
|
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
|
}
|
|
|
|
func (s *WriteSuite) TestInsert() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("success", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collName, s.schema)
|
|
|
|
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
|
|
s.Equal(collName, ir.GetCollectionName())
|
|
s.Equal(partName, ir.GetPartitionName())
|
|
s.Require().Len(ir.GetFieldsData(), 2)
|
|
s.EqualValues(3, ir.GetNumRows())
|
|
return &milvuspb.MutationResult{
|
|
Status: merr.Success(),
|
|
InsertCnt: 3,
|
|
IDs: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})).
|
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
|
s.NoError(err)
|
|
s.EqualValues(3, result.InsertCount)
|
|
})
|
|
|
|
s.Run("dynamic_schema", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collName, s.schemaDyn)
|
|
|
|
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
|
|
s.Equal(collName, ir.GetCollectionName())
|
|
s.Equal(partName, ir.GetPartitionName())
|
|
s.Require().Len(ir.GetFieldsData(), 3)
|
|
s.EqualValues(3, ir.GetNumRows())
|
|
return &milvuspb.MutationResult{
|
|
Status: merr.Success(),
|
|
InsertCnt: 3,
|
|
IDs: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})).
|
|
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
|
s.NoError(err)
|
|
s.EqualValues(3, result.InsertCount)
|
|
})
|
|
|
|
s.Run("bad_input", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collName, s.schema)
|
|
|
|
type badCase struct {
|
|
tag string
|
|
input InsertOption
|
|
}
|
|
|
|
cases := []badCase{
|
|
{
|
|
tag: "missing_column",
|
|
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
|
|
},
|
|
{
|
|
tag: "row_count_not_match",
|
|
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})),
|
|
},
|
|
{
|
|
tag: "duplicated_columns",
|
|
input: NewColumnBasedInsertOption(collName).
|
|
WithInt64Column("id", []int64{1}).
|
|
WithInt64Column("id", []int64{2}).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})),
|
|
},
|
|
{
|
|
tag: "different_data_type",
|
|
input: NewColumnBasedInsertOption(collName).
|
|
WithVarcharColumn("id", []string{"1"}).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})),
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
s.Run(tc.tag, func() {
|
|
_, err := s.client.Insert(ctx, tc.input)
|
|
s.Error(err)
|
|
})
|
|
}
|
|
})
|
|
|
|
s.Run("failure", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collName, s.schema)
|
|
|
|
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
|
|
|
_, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})).
|
|
WithInt64Column("id", []int64{1, 2, 3}))
|
|
s.Error(err)
|
|
})
|
|
}
|
|
|
|
func (s *WriteSuite) TestUpsert() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("success", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collName, s.schema)
|
|
|
|
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
|
|
s.Equal(collName, ur.GetCollectionName())
|
|
s.Equal(partName, ur.GetPartitionName())
|
|
s.Require().Len(ur.GetFieldsData(), 2)
|
|
s.EqualValues(3, ur.GetNumRows())
|
|
return &milvuspb.MutationResult{
|
|
Status: merr.Success(),
|
|
UpsertCnt: 3,
|
|
IDs: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})).
|
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
|
s.NoError(err)
|
|
s.EqualValues(3, result.UpsertCount)
|
|
})
|
|
|
|
s.Run("dynamic_schema", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collName, s.schemaDyn)
|
|
|
|
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
|
|
s.Equal(collName, ur.GetCollectionName())
|
|
s.Equal(partName, ur.GetPartitionName())
|
|
s.Require().Len(ur.GetFieldsData(), 3)
|
|
s.EqualValues(3, ur.GetNumRows())
|
|
return &milvuspb.MutationResult{
|
|
Status: merr.Success(),
|
|
UpsertCnt: 3,
|
|
IDs: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3},
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})).
|
|
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
|
s.NoError(err)
|
|
s.EqualValues(3, result.UpsertCount)
|
|
})
|
|
|
|
s.Run("bad_input", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collName, s.schema)
|
|
|
|
type badCase struct {
|
|
tag string
|
|
input UpsertOption
|
|
}
|
|
|
|
cases := []badCase{
|
|
{
|
|
tag: "missing_column",
|
|
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
|
|
},
|
|
{
|
|
tag: "row_count_not_match",
|
|
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})),
|
|
},
|
|
{
|
|
tag: "duplicated_columns",
|
|
input: NewColumnBasedInsertOption(collName).
|
|
WithInt64Column("id", []int64{1}).
|
|
WithInt64Column("id", []int64{2}).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})),
|
|
},
|
|
{
|
|
tag: "different_data_type",
|
|
input: NewColumnBasedInsertOption(collName).
|
|
WithVarcharColumn("id", []string{"1"}).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})),
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
s.Run(tc.tag, func() {
|
|
_, err := s.client.Upsert(ctx, tc.input)
|
|
s.Error(err)
|
|
})
|
|
}
|
|
})
|
|
|
|
s.Run("failure", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collName, s.schema)
|
|
|
|
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
|
|
|
_, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
|
})).
|
|
WithInt64Column("id", []int64{1, 2, 3}))
|
|
s.Error(err)
|
|
})
|
|
}
|
|
|
|
func (s *WriteSuite) TestDelete() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("success", func() {
|
|
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partName := fmt.Sprintf("part_%s", s.randString(6))
|
|
|
|
type testCase struct {
|
|
tag string
|
|
input DeleteOption
|
|
expectExpr string
|
|
}
|
|
|
|
cases := []testCase{
|
|
{
|
|
tag: "raw_expr",
|
|
input: NewDeleteOption(collName).WithPartition(partName).WithExpr("id > 100"),
|
|
expectExpr: "id > 100",
|
|
},
|
|
{
|
|
tag: "int_ids",
|
|
input: NewDeleteOption(collName).WithPartition(partName).WithInt64IDs("id", []int64{1, 2, 3}),
|
|
expectExpr: "id in [1,2,3]",
|
|
},
|
|
{
|
|
tag: "str_ids",
|
|
input: NewDeleteOption(collName).WithPartition(partName).WithStringIDs("id", []string{"a", "b", "c"}),
|
|
expectExpr: `id in ["a","b","c"]`,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
s.Run(tc.tag, func() {
|
|
s.mock.EXPECT().Delete(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dr *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
|
|
s.Equal(collName, dr.GetCollectionName())
|
|
s.Equal(partName, dr.GetPartitionName())
|
|
s.Equal(tc.expectExpr, dr.GetExpr())
|
|
return &milvuspb.MutationResult{
|
|
Status: merr.Success(),
|
|
DeleteCnt: 100,
|
|
}, nil
|
|
}).Once()
|
|
result, err := s.client.Delete(ctx, tc.input)
|
|
s.NoError(err)
|
|
s.EqualValues(100, result.DeleteCount)
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestWrite(t *testing.T) {
|
|
suite.Run(t, new(WriteSuite))
|
|
}
|