From 8c9dab5cf79031fe7cd0195507fe6a48e739f7ba Mon Sep 17 00:00:00 2001 From: congqixia Date: Sun, 24 Nov 2024 14:48:33 +0800 Subject: [PATCH] enhance: [GoSDK] Add Slice method for Vector Columns (#37951) Related to #37768 Signed-off-by: Congqi Xia --- client/column/sparse.go | 6 ++ client/column/sparse_test.go | 9 ++ client/column/vector.go | 31 +++++++ client/column/vector_test.go | 90 +++++++++++++++++++ client/go.mod | 2 +- client/go.sum | 4 +- .../milvusclient/mock_milvus_server_test.go | 59 ++++++++++++ 7 files changed, 198 insertions(+), 3 deletions(-) diff --git a/client/column/sparse.go b/client/column/sparse.go index df99694a6b..8b68ace541 100644 --- a/client/column/sparse.go +++ b/client/column/sparse.go @@ -45,3 +45,9 @@ func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData { vectors.Dim = int64(max.Dim()) return fd } + +func (c *ColumnSparseFloatVector) Slice(start, end int) Column { + return &ColumnSparseFloatVector{ + vectorBase: c.vectorBase.slice(start, end), + } +} diff --git a/client/column/sparse_test.go b/client/column/sparse_test.go index bd3b6c89fb..bcb9223714 100644 --- a/client/column/sparse_test.go +++ b/client/column/sparse_test.go @@ -91,4 +91,13 @@ func TestColumnSparseEmbedding(t *testing.T) { assert.Equal(t, v, getV) } }) + + t.Run("test_column_slice", func(t *testing.T) { + l := rand.Intn(columnLen) + sliced := column.Slice(0, l) + slicedColumn, ok := sliced.(*ColumnSparseFloatVector) + if assert.True(t, ok) { + assert.Equal(t, column.Data()[:l], slicedColumn.Data()) + } + }) } diff --git a/client/column/vector.go b/client/column/vector.go index 6e3a3e1dac..e6b90bd097 100644 --- a/client/column/vector.go +++ b/client/column/vector.go @@ -40,6 +40,13 @@ func (b *vectorBase[T]) FieldData() *schemapb.FieldData { return fd } +func (b *vectorBase[T]) slice(start, end int) *vectorBase[T] { + return &vectorBase[T]{ + genericColumnBase: b.genericColumnBase.slice(start, end), + dim: b.dim, + } +} + func newVectorBase[T entity.Vector](fieldName string, dim int, vectors []T, fieldType entity.FieldType) *vectorBase[T] { return &vectorBase[T]{ genericColumnBase: &genericColumnBase[T]{ @@ -78,6 +85,12 @@ func (c *ColumnFloatVector) AppendValue(i interface{}) error { return nil } +func (c *ColumnFloatVector) Slice(start, end int) Column { + return &ColumnFloatVector{ + vectorBase: c.vectorBase.slice(start, end), + } +} + /* binary vector */ type ColumnBinaryVector struct { @@ -105,6 +118,12 @@ func (c *ColumnBinaryVector) AppendValue(i interface{}) error { return nil } +func (c *ColumnBinaryVector) Slice(start, end int) Column { + return &ColumnBinaryVector{ + vectorBase: c.vectorBase.slice(start, end), + } +} + /* fp16 vector */ type ColumnFloat16Vector struct { @@ -132,6 +151,12 @@ func (c *ColumnFloat16Vector) AppendValue(i interface{}) error { return nil } +func (c *ColumnFloat16Vector) Slice(start, end int) Column { + return &ColumnFloat16Vector{ + vectorBase: c.vectorBase.slice(start, end), + } +} + type ColumnBFloat16Vector struct { *vectorBase[entity.BFloat16Vector] } @@ -156,3 +181,9 @@ func (c *ColumnBFloat16Vector) AppendValue(i interface{}) error { } return nil } + +func (c *ColumnBFloat16Vector) Slice(start, end int) Column { + return &ColumnBFloat16Vector{ + vectorBase: c.vectorBase.slice(start, end), + } +} diff --git a/client/column/vector_test.go b/client/column/vector_test.go index 6dfcae77b0..03e25ba5e0 100644 --- a/client/column/vector_test.go +++ b/client/column/vector_test.go @@ -162,6 +162,96 @@ func (s *VectorSuite) TestBasic() { }) } +func (s *VectorSuite) TestSlice() { + s.Run("float_vector", func() { + name := fmt.Sprintf("field_%d", rand.Intn(1000)) + n := 100 + dim := rand.Intn(10) + 2 + data := make([][]float32, 0, n) + for i := 0; i < n; i++ { + row := lo.RepeatBy(dim, func(i int) float32 { + return rand.Float32() + }) + data = append(data, row) + } + column := NewColumnFloatVector(name, dim, data) + + l := rand.Intn(n) + sliced := column.Slice(0, l) + slicedColumn, ok := sliced.(*ColumnFloatVector) + if s.True(ok) { + s.Equal(dim, slicedColumn.Dim()) + s.Equal(lo.Map(data[:l], func(row []float32, _ int) entity.FloatVector { return entity.FloatVector(row) }), slicedColumn.Data()) + } + }) + + s.Run("binary_vector", func() { + name := fmt.Sprintf("field_%d", rand.Intn(1000)) + n := 100 + dim := (rand.Intn(10) + 1) * 8 + data := make([][]byte, 0, n) + for i := 0; i < n; i++ { + row := lo.RepeatBy(dim/8, func(i int) byte { + return byte(rand.Intn(math.MaxUint8)) + }) + data = append(data, row) + } + column := NewColumnBinaryVector(name, dim, data) + + l := rand.Intn(n) + sliced := column.Slice(0, l) + slicedColumn, ok := sliced.(*ColumnBinaryVector) + if s.True(ok) { + s.Equal(dim, slicedColumn.Dim()) + s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.BinaryVector { return entity.BinaryVector(row) }), slicedColumn.Data()) + } + }) + + s.Run("fp16_vector", func() { + name := fmt.Sprintf("field_%d", rand.Intn(1000)) + n := 3 + dim := rand.Intn(10) + 1 + data := make([][]byte, 0, n) + for i := 0; i < n; i++ { + row := lo.RepeatBy(dim*2, func(i int) byte { + return byte(rand.Intn(math.MaxUint8)) + }) + data = append(data, row) + } + column := NewColumnFloat16Vector(name, dim, data) + + l := rand.Intn(n) + sliced := column.Slice(0, l) + slicedColumn, ok := sliced.(*ColumnFloat16Vector) + if s.True(ok) { + s.Equal(dim, slicedColumn.Dim()) + s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.Float16Vector { return entity.Float16Vector(row) }), slicedColumn.Data()) + } + }) + + s.Run("bf16_vector", func() { + name := fmt.Sprintf("field_%d", rand.Intn(1000)) + n := 3 + dim := rand.Intn(10) + 1 + data := make([][]byte, 0, n) + for i := 0; i < n; i++ { + row := lo.RepeatBy(dim*2, func(i int) byte { + return byte(rand.Intn(math.MaxUint8)) + }) + data = append(data, row) + } + column := NewColumnBFloat16Vector(name, dim, data) + + l := rand.Intn(n) + sliced := column.Slice(0, l) + slicedColumn, ok := sliced.(*ColumnBFloat16Vector) + if s.True(ok) { + s.Equal(dim, slicedColumn.Dim()) + s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.BFloat16Vector { return entity.BFloat16Vector(row) }), slicedColumn.Data()) + } + }) +} + func TestVectors(t *testing.T) { suite.Run(t, new(VectorSuite)) } diff --git a/client/go.mod b/client/go.mod index 60b3b0904e..46d606eed2 100644 --- a/client/go.mod +++ b/client/go.mod @@ -6,7 +6,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/cockroachdb/errors v1.9.1 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241108105827-266fb751b620 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb github.com/quasilyte/go-ruleguard/dsl v0.3.22 github.com/samber/lo v1.27.0 diff --git a/client/go.sum b/client/go.sum index 255d79dfcf..958199431f 100644 --- a/client/go.sum +++ b/client/go.sum @@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241108105827-266fb751b620 h1:0IWUDtDloift7cQHalhdjuVkL/3qSeiXFqR7MofZBkg= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241108105827-266fb751b620/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb h1:lMyIrG03agASB88AAwnk+NOU9V33lcBdtub/ZEv6IQU= github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb/go.mod h1:w5nu1Z318AvgWQrGUYXaqLeVLu4JvCS/oYhxqctOZvU= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= diff --git a/client/milvusclient/mock_milvus_server_test.go b/client/milvusclient/mock_milvus_server_test.go index 6e0b4da8e6..aa917cfc5a 100644 --- a/client/milvusclient/mock_milvus_server_test.go +++ b/client/milvusclient/mock_milvus_server_test.go @@ -4393,6 +4393,65 @@ func (_c *MilvusServiceServer_OperatePrivilegeGroup_Call) RunAndReturn(run func( return _c } +// OperatePrivilegeV2 provides a mock function with given fields: _a0, _a1 +func (_m *MilvusServiceServer) OperatePrivilegeV2(_a0 context.Context, _a1 *milvuspb.OperatePrivilegeV2Request) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for OperatePrivilegeV2") + } + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeV2Request) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeV2Request) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperatePrivilegeV2Request) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MilvusServiceServer_OperatePrivilegeV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperatePrivilegeV2' +type MilvusServiceServer_OperatePrivilegeV2_Call struct { + *mock.Call +} + +// OperatePrivilegeV2 is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.OperatePrivilegeV2Request +func (_e *MilvusServiceServer_Expecter) OperatePrivilegeV2(_a0 interface{}, _a1 interface{}) *MilvusServiceServer_OperatePrivilegeV2_Call { + return &MilvusServiceServer_OperatePrivilegeV2_Call{Call: _e.mock.On("OperatePrivilegeV2", _a0, _a1)} +} + +func (_c *MilvusServiceServer_OperatePrivilegeV2_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.OperatePrivilegeV2Request)) *MilvusServiceServer_OperatePrivilegeV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.OperatePrivilegeV2Request)) + }) + return _c +} + +func (_c *MilvusServiceServer_OperatePrivilegeV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MilvusServiceServer_OperatePrivilegeV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MilvusServiceServer_OperatePrivilegeV2_Call) RunAndReturn(run func(context.Context, *milvuspb.OperatePrivilegeV2Request) (*commonpb.Status, error)) *MilvusServiceServer_OperatePrivilegeV2_Call { + _c.Call.Return(run) + return _c +} + // OperateUserRole provides a mock function with given fields: _a0, _a1 func (_m *MilvusServiceServer) OperateUserRole(_a0 context.Context, _a1 *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1)