mirror of https://github.com/milvus-io/milvus.git
				
				
				
			
		
			
				
	
	
		
			371 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			371 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
 | 
						|
package client
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"math/rand"
 | 
						|
	"testing"
 | 
						|
 | 
						|
	"github.com/samber/lo"
 | 
						|
	"github.com/stretchr/testify/mock"
 | 
						|
	"github.com/stretchr/testify/suite"
 | 
						|
 | 
						|
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 | 
						|
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 | 
						|
	"github.com/milvus-io/milvus/client/v2/entity"
 | 
						|
	"github.com/milvus-io/milvus/pkg/util/merr"
 | 
						|
)
 | 
						|
 | 
						|
type WriteSuite struct {
 | 
						|
	MockSuiteBase
 | 
						|
 | 
						|
	schema    *entity.Schema
 | 
						|
	schemaDyn *entity.Schema
 | 
						|
}
 | 
						|
 | 
						|
func (s *WriteSuite) SetupSuite() {
 | 
						|
	s.MockSuiteBase.SetupSuite()
 | 
						|
	s.schema = entity.NewSchema().
 | 
						|
		WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
 | 
						|
		WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
 | 
						|
 | 
						|
	s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
 | 
						|
		WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
 | 
						|
		WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
 | 
						|
}
 | 
						|
 | 
						|
func (s *WriteSuite) TestInsert() {
 | 
						|
	ctx, cancel := context.WithCancel(context.Background())
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	s.Run("success", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		partName := fmt.Sprintf("part_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schema)
 | 
						|
 | 
						|
		s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
 | 
						|
			s.Equal(collName, ir.GetCollectionName())
 | 
						|
			s.Equal(partName, ir.GetPartitionName())
 | 
						|
			s.Require().Len(ir.GetFieldsData(), 2)
 | 
						|
			s.EqualValues(3, ir.GetNumRows())
 | 
						|
			return &milvuspb.MutationResult{
 | 
						|
				Status:    merr.Success(),
 | 
						|
				InsertCnt: 3,
 | 
						|
				IDs: &schemapb.IDs{
 | 
						|
					IdField: &schemapb.IDs_IntId{
 | 
						|
						IntId: &schemapb.LongArray{
 | 
						|
							Data: []int64{1, 2, 3},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			}, nil
 | 
						|
		}).Once()
 | 
						|
 | 
						|
		result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
 | 
						|
			WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
				return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
			})).
 | 
						|
			WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
 | 
						|
		s.NoError(err)
 | 
						|
		s.EqualValues(3, result.InsertCount)
 | 
						|
	})
 | 
						|
 | 
						|
	s.Run("dynamic_schema", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		partName := fmt.Sprintf("part_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schemaDyn)
 | 
						|
 | 
						|
		s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
 | 
						|
			s.Equal(collName, ir.GetCollectionName())
 | 
						|
			s.Equal(partName, ir.GetPartitionName())
 | 
						|
			s.Require().Len(ir.GetFieldsData(), 3)
 | 
						|
			s.EqualValues(3, ir.GetNumRows())
 | 
						|
			return &milvuspb.MutationResult{
 | 
						|
				Status:    merr.Success(),
 | 
						|
				InsertCnt: 3,
 | 
						|
				IDs: &schemapb.IDs{
 | 
						|
					IdField: &schemapb.IDs_IntId{
 | 
						|
						IntId: &schemapb.LongArray{
 | 
						|
							Data: []int64{1, 2, 3},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			}, nil
 | 
						|
		}).Once()
 | 
						|
 | 
						|
		result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
 | 
						|
			WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
				return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
			})).
 | 
						|
			WithVarcharColumn("extra", []string{"a", "b", "c"}).
 | 
						|
			WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
 | 
						|
		s.NoError(err)
 | 
						|
		s.EqualValues(3, result.InsertCount)
 | 
						|
	})
 | 
						|
 | 
						|
	s.Run("bad_input", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schema)
 | 
						|
 | 
						|
		type badCase struct {
 | 
						|
			tag   string
 | 
						|
			input InsertOption
 | 
						|
		}
 | 
						|
 | 
						|
		cases := []badCase{
 | 
						|
			{
 | 
						|
				tag:   "missing_column",
 | 
						|
				input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag: "row_count_not_match",
 | 
						|
				input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
 | 
						|
					WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
						return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
					})),
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag: "duplicated_columns",
 | 
						|
				input: NewColumnBasedInsertOption(collName).
 | 
						|
					WithInt64Column("id", []int64{1}).
 | 
						|
					WithInt64Column("id", []int64{2}).
 | 
						|
					WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
 | 
						|
						return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
					})),
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag: "different_data_type",
 | 
						|
				input: NewColumnBasedInsertOption(collName).
 | 
						|
					WithVarcharColumn("id", []string{"1"}).
 | 
						|
					WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
 | 
						|
						return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
					})),
 | 
						|
			},
 | 
						|
		}
 | 
						|
 | 
						|
		for _, tc := range cases {
 | 
						|
			s.Run(tc.tag, func() {
 | 
						|
				_, err := s.client.Insert(ctx, tc.input)
 | 
						|
				s.Error(err)
 | 
						|
			})
 | 
						|
		}
 | 
						|
	})
 | 
						|
 | 
						|
	s.Run("failure", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schema)
 | 
						|
 | 
						|
		s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
 | 
						|
 | 
						|
		_, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
 | 
						|
			WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
				return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
			})).
 | 
						|
			WithInt64Column("id", []int64{1, 2, 3}))
 | 
						|
		s.Error(err)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func (s *WriteSuite) TestUpsert() {
 | 
						|
	ctx, cancel := context.WithCancel(context.Background())
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	s.Run("success", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		partName := fmt.Sprintf("part_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schema)
 | 
						|
 | 
						|
		s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
 | 
						|
			s.Equal(collName, ur.GetCollectionName())
 | 
						|
			s.Equal(partName, ur.GetPartitionName())
 | 
						|
			s.Require().Len(ur.GetFieldsData(), 2)
 | 
						|
			s.EqualValues(3, ur.GetNumRows())
 | 
						|
			return &milvuspb.MutationResult{
 | 
						|
				Status:    merr.Success(),
 | 
						|
				UpsertCnt: 3,
 | 
						|
				IDs: &schemapb.IDs{
 | 
						|
					IdField: &schemapb.IDs_IntId{
 | 
						|
						IntId: &schemapb.LongArray{
 | 
						|
							Data: []int64{1, 2, 3},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			}, nil
 | 
						|
		}).Once()
 | 
						|
 | 
						|
		result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
 | 
						|
			WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
				return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
			})).
 | 
						|
			WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
 | 
						|
		s.NoError(err)
 | 
						|
		s.EqualValues(3, result.UpsertCount)
 | 
						|
	})
 | 
						|
 | 
						|
	s.Run("dynamic_schema", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		partName := fmt.Sprintf("part_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schemaDyn)
 | 
						|
 | 
						|
		s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
 | 
						|
			s.Equal(collName, ur.GetCollectionName())
 | 
						|
			s.Equal(partName, ur.GetPartitionName())
 | 
						|
			s.Require().Len(ur.GetFieldsData(), 3)
 | 
						|
			s.EqualValues(3, ur.GetNumRows())
 | 
						|
			return &milvuspb.MutationResult{
 | 
						|
				Status:    merr.Success(),
 | 
						|
				UpsertCnt: 3,
 | 
						|
				IDs: &schemapb.IDs{
 | 
						|
					IdField: &schemapb.IDs_IntId{
 | 
						|
						IntId: &schemapb.LongArray{
 | 
						|
							Data: []int64{1, 2, 3},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			}, nil
 | 
						|
		}).Once()
 | 
						|
 | 
						|
		result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
 | 
						|
			WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
				return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
			})).
 | 
						|
			WithVarcharColumn("extra", []string{"a", "b", "c"}).
 | 
						|
			WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
 | 
						|
		s.NoError(err)
 | 
						|
		s.EqualValues(3, result.UpsertCount)
 | 
						|
	})
 | 
						|
 | 
						|
	s.Run("bad_input", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schema)
 | 
						|
 | 
						|
		type badCase struct {
 | 
						|
			tag   string
 | 
						|
			input UpsertOption
 | 
						|
		}
 | 
						|
 | 
						|
		cases := []badCase{
 | 
						|
			{
 | 
						|
				tag:   "missing_column",
 | 
						|
				input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag: "row_count_not_match",
 | 
						|
				input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
 | 
						|
					WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
						return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
					})),
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag: "duplicated_columns",
 | 
						|
				input: NewColumnBasedInsertOption(collName).
 | 
						|
					WithInt64Column("id", []int64{1}).
 | 
						|
					WithInt64Column("id", []int64{2}).
 | 
						|
					WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
 | 
						|
						return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
					})),
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag: "different_data_type",
 | 
						|
				input: NewColumnBasedInsertOption(collName).
 | 
						|
					WithVarcharColumn("id", []string{"1"}).
 | 
						|
					WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
 | 
						|
						return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
					})),
 | 
						|
			},
 | 
						|
		}
 | 
						|
 | 
						|
		for _, tc := range cases {
 | 
						|
			s.Run(tc.tag, func() {
 | 
						|
				_, err := s.client.Upsert(ctx, tc.input)
 | 
						|
				s.Error(err)
 | 
						|
			})
 | 
						|
		}
 | 
						|
	})
 | 
						|
 | 
						|
	s.Run("failure", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		s.setupCache(collName, s.schema)
 | 
						|
 | 
						|
		s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
 | 
						|
 | 
						|
		_, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
 | 
						|
			WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
 | 
						|
				return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
 | 
						|
			})).
 | 
						|
			WithInt64Column("id", []int64{1, 2, 3}))
 | 
						|
		s.Error(err)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func (s *WriteSuite) TestDelete() {
 | 
						|
	ctx, cancel := context.WithCancel(context.Background())
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	s.Run("success", func() {
 | 
						|
		collName := fmt.Sprintf("coll_%s", s.randString(6))
 | 
						|
		partName := fmt.Sprintf("part_%s", s.randString(6))
 | 
						|
 | 
						|
		type testCase struct {
 | 
						|
			tag        string
 | 
						|
			input      DeleteOption
 | 
						|
			expectExpr string
 | 
						|
		}
 | 
						|
 | 
						|
		cases := []testCase{
 | 
						|
			{
 | 
						|
				tag:        "raw_expr",
 | 
						|
				input:      NewDeleteOption(collName).WithPartition(partName).WithExpr("id > 100"),
 | 
						|
				expectExpr: "id > 100",
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag:        "int_ids",
 | 
						|
				input:      NewDeleteOption(collName).WithPartition(partName).WithInt64IDs("id", []int64{1, 2, 3}),
 | 
						|
				expectExpr: "id in [1,2,3]",
 | 
						|
			},
 | 
						|
			{
 | 
						|
				tag:        "str_ids",
 | 
						|
				input:      NewDeleteOption(collName).WithPartition(partName).WithStringIDs("id", []string{"a", "b", "c"}),
 | 
						|
				expectExpr: `id in ["a","b","c"]`,
 | 
						|
			},
 | 
						|
		}
 | 
						|
 | 
						|
		for _, tc := range cases {
 | 
						|
			s.Run(tc.tag, func() {
 | 
						|
				s.mock.EXPECT().Delete(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dr *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
 | 
						|
					s.Equal(collName, dr.GetCollectionName())
 | 
						|
					s.Equal(partName, dr.GetPartitionName())
 | 
						|
					s.Equal(tc.expectExpr, dr.GetExpr())
 | 
						|
					return &milvuspb.MutationResult{
 | 
						|
						Status:    merr.Success(),
 | 
						|
						DeleteCnt: 100,
 | 
						|
					}, nil
 | 
						|
				}).Once()
 | 
						|
				result, err := s.client.Delete(ctx, tc.input)
 | 
						|
				s.NoError(err)
 | 
						|
				s.EqualValues(100, result.DeleteCount)
 | 
						|
			})
 | 
						|
		}
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func TestWrite(t *testing.T) {
 | 
						|
	suite.Run(t, new(WriteSuite))
 | 
						|
}
 |