mirror of https://github.com/milvus-io/milvus.git
Add logic to deserialize rawData into InsertData
Signed-off-by: XuanYang-cn <xuan.yang@zilliz.com>
pull/4973/head^2
parent d5daa18392
commit 87521adfbd
@@ -32,45 +32,45 @@ type FieldData interface{}

 type BoolFieldData struct {
 	NumRows int
-	data    []bool
+	Data    []bool
 }
 type Int8FieldData struct {
 	NumRows int
-	data    []int8
+	Data    []int8
 }
 type Int16FieldData struct {
 	NumRows int
-	data    []int16
+	Data    []int16
 }
 type Int32FieldData struct {
 	NumRows int
-	data    []int32
+	Data    []int32
 }
 type Int64FieldData struct {
 	NumRows int
-	data    []int64
+	Data    []int64
 }
 type FloatFieldData struct {
 	NumRows int
-	data    []float32
+	Data    []float32
 }
 type DoubleFieldData struct {
 	NumRows int
-	data    []float64
+	Data    []float64
 }
 type StringFieldData struct {
 	NumRows int
-	data    []string
+	Data    []string
 }
 type BinaryVectorFieldData struct {
 	NumRows int
-	data    []byte
-	dim     int
+	Data    []byte
+	Dim     int
 }
 type FloatVectorFieldData struct {
 	NumRows int
-	data    []float32
-	dim     int
+	Data    []float32
+	Dim     int
 }

 // system filed id:
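This rename is the point of the commit: with Data and Dim exported, code outside the storage package (such as the write node further down) can populate an InsertData directly. A minimal sketch, assuming only what this diff shows (InsertData.Data maps a field ID to a FieldData, as in the test fixture below):

    // sketch: constructing an InsertData from another package
    idata := &storage.InsertData{
        Data: map[int64]storage.FieldData{
            109: storage.FloatVectorFieldData{
                NumRows: 1,
                Data:    []float32{0, 1},
                Dim:     2,
            },
        },
    }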
@@ -101,7 +101,7 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique
 	if !ok {
 		return nil, errors.New("data doesn't contains timestamp field")
 	}
-	ts := timeFieldData.(Int64FieldData).data
+	ts := timeFieldData.(Int64FieldData).Data

 	for _, field := range insertCodec.Schema.Schema.Fields {
 		singleData := data.Data[field.FieldID]

@@ -117,30 +117,30 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique
 		eventWriter.SetEndTimestamp(typeutil.Timestamp(ts[len(ts)-1]))
 		switch field.DataType {
 		case schemapb.DataType_BOOL:
-			err = eventWriter.AddBoolToPayload(singleData.(BoolFieldData).data)
+			err = eventWriter.AddBoolToPayload(singleData.(BoolFieldData).Data)
 		case schemapb.DataType_INT8:
-			err = eventWriter.AddInt8ToPayload(singleData.(Int8FieldData).data)
+			err = eventWriter.AddInt8ToPayload(singleData.(Int8FieldData).Data)
 		case schemapb.DataType_INT16:
-			err = eventWriter.AddInt16ToPayload(singleData.(Int16FieldData).data)
+			err = eventWriter.AddInt16ToPayload(singleData.(Int16FieldData).Data)
 		case schemapb.DataType_INT32:
-			err = eventWriter.AddInt32ToPayload(singleData.(Int32FieldData).data)
+			err = eventWriter.AddInt32ToPayload(singleData.(Int32FieldData).Data)
 		case schemapb.DataType_INT64:
-			err = eventWriter.AddInt64ToPayload(singleData.(Int64FieldData).data)
+			err = eventWriter.AddInt64ToPayload(singleData.(Int64FieldData).Data)
 		case schemapb.DataType_FLOAT:
-			err = eventWriter.AddFloatToPayload(singleData.(FloatFieldData).data)
+			err = eventWriter.AddFloatToPayload(singleData.(FloatFieldData).Data)
 		case schemapb.DataType_DOUBLE:
-			err = eventWriter.AddDoubleToPayload(singleData.(DoubleFieldData).data)
+			err = eventWriter.AddDoubleToPayload(singleData.(DoubleFieldData).Data)
 		case schemapb.DataType_STRING:
-			for _, singleString := range singleData.(StringFieldData).data {
+			for _, singleString := range singleData.(StringFieldData).Data {
 				err = eventWriter.AddOneStringToPayload(singleString)
 				if err != nil {
 					return nil, err
 				}
 			}
 		case schemapb.DataType_VECTOR_BINARY:
-			err = eventWriter.AddBinaryVectorToPayload(singleData.(BinaryVectorFieldData).data, singleData.(BinaryVectorFieldData).dim)
+			err = eventWriter.AddBinaryVectorToPayload(singleData.(BinaryVectorFieldData).Data, singleData.(BinaryVectorFieldData).Dim)
 		case schemapb.DataType_VECTOR_FLOAT:
-			err = eventWriter.AddFloatVectorToPayload(singleData.(FloatVectorFieldData).data, singleData.(FloatVectorFieldData).dim)
+			err = eventWriter.AddFloatVectorToPayload(singleData.(FloatVectorFieldData).Data, singleData.(FloatVectorFieldData).Dim)
 		default:
 			return nil, errors.Errorf("undefined data type %d", field.DataType)
 		}
@@ -201,11 +201,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			boolFieldData.data, err = eventReader.GetBoolFromPayload()
+			boolFieldData.Data, err = eventReader.GetBoolFromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			boolFieldData.NumRows = len(boolFieldData.data)
+			boolFieldData.NumRows = len(boolFieldData.Data)
 			resultData.Data[fieldID] = boolFieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_INT8:

@@ -214,11 +214,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int8FieldData.data, err = eventReader.GetInt8FromPayload()
+			int8FieldData.Data, err = eventReader.GetInt8FromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int8FieldData.NumRows = len(int8FieldData.data)
+			int8FieldData.NumRows = len(int8FieldData.Data)
 			resultData.Data[fieldID] = int8FieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_INT16:

@@ -227,11 +227,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int16FieldData.data, err = eventReader.GetInt16FromPayload()
+			int16FieldData.Data, err = eventReader.GetInt16FromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int16FieldData.NumRows = len(int16FieldData.data)
+			int16FieldData.NumRows = len(int16FieldData.Data)
 			resultData.Data[fieldID] = int16FieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_INT32:

@@ -240,11 +240,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int32FieldData.data, err = eventReader.GetInt32FromPayload()
+			int32FieldData.Data, err = eventReader.GetInt32FromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int32FieldData.NumRows = len(int32FieldData.data)
+			int32FieldData.NumRows = len(int32FieldData.Data)
 			resultData.Data[fieldID] = int32FieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_INT64:

@@ -253,11 +253,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int64FieldData.data, err = eventReader.GetInt64FromPayload()
+			int64FieldData.Data, err = eventReader.GetInt64FromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			int64FieldData.NumRows = len(int64FieldData.data)
+			int64FieldData.NumRows = len(int64FieldData.Data)
 			resultData.Data[fieldID] = int64FieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_FLOAT:

@@ -266,11 +266,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			floatFieldData.data, err = eventReader.GetFloatFromPayload()
+			floatFieldData.Data, err = eventReader.GetFloatFromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			floatFieldData.NumRows = len(floatFieldData.data)
+			floatFieldData.NumRows = len(floatFieldData.Data)
 			resultData.Data[fieldID] = floatFieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_DOUBLE:

@@ -279,11 +279,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			doubleFieldData.data, err = eventReader.GetDoubleFromPayload()
+			doubleFieldData.Data, err = eventReader.GetDoubleFromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			doubleFieldData.NumRows = len(doubleFieldData.data)
+			doubleFieldData.NumRows = len(doubleFieldData.Data)
 			resultData.Data[fieldID] = doubleFieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_STRING:

@@ -302,7 +302,7 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 				if err != nil {
 					return -1, -1, nil, err
 				}
-				stringFieldData.data = append(stringFieldData.data, singleString)
+				stringFieldData.Data = append(stringFieldData.Data, singleString)
 			}
 			resultData.Data[fieldID] = stringFieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))

@@ -312,11 +312,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			binaryVectorFieldData.data, binaryVectorFieldData.dim, err = eventReader.GetBinaryVectorFromPayload()
+			binaryVectorFieldData.Data, binaryVectorFieldData.Dim, err = eventReader.GetBinaryVectorFromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			binaryVectorFieldData.NumRows = len(binaryVectorFieldData.data)
+			binaryVectorFieldData.NumRows = len(binaryVectorFieldData.Data)
 			resultData.Data[fieldID] = binaryVectorFieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		case schemapb.DataType_VECTOR_FLOAT:

@@ -325,11 +325,11 @@ func (insertCodec *InsertCodec) Deserialize(blobs []*Blob) (partitionID UniqueID
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			floatVectorFieldData.data, floatVectorFieldData.dim, err = eventReader.GetFloatVectorFromPayload()
+			floatVectorFieldData.Data, floatVectorFieldData.Dim, err = eventReader.GetFloatVectorFromPayload()
 			if err != nil {
 				return -1, -1, nil, err
 			}
-			floatVectorFieldData.NumRows = len(floatVectorFieldData.data) / 8
+			floatVectorFieldData.NumRows = len(floatVectorFieldData.Data) / 8
 			resultData.Data[fieldID] = floatVectorFieldData
 			insertCodec.readerCloseFunc = append(insertCodec.readerCloseFunc, readerClose(binlogReader))
 		default:
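Two details worth noting in these Deserialize hunks: each case recomputes NumRows from the decoded payload length, and the float-vector case divides by a hardcoded 8 rather than by the freshly read Dim, which only agrees with the dim-8 fixture in the test below. A hedged round-trip sketch follows; the parameter lists of Serialize and Deserialize are truncated in this view, so the (partitionID, segmentID, *InsertData) argument order is an assumption:

    // sketch, not the committed API: exercise Serialize then Deserialize
    func roundTrip(codec *storage.InsertCodec, data *storage.InsertData) (*storage.InsertData, error) {
        blobs, err := codec.Serialize(0, 0, data) // assumed (partitionID, segmentID, data)
        if err != nil {
            return nil, err
        }
        // Deserialize returns (partitionID, segmentID, data, err) per the hunks above
        _, _, decoded, err := codec.Deserialize(blobs)
        if err != nil {
            return nil, err
        }
        return decoded, nil
    }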
@@ -112,49 +112,49 @@ func TestInsertCodec(t *testing.T) {
 		Data: map[int64]FieldData{
 			1: Int64FieldData{
 				NumRows: 2,
-				data:    []int64{1, 2},
+				Data:    []int64{1, 2},
 			},
 			100: BoolFieldData{
 				NumRows: 2,
-				data:    []bool{true, false},
+				Data:    []bool{true, false},
 			},
 			101: Int8FieldData{
 				NumRows: 2,
-				data:    []int8{1, 2},
+				Data:    []int8{1, 2},
 			},
 			102: Int16FieldData{
 				NumRows: 2,
-				data:    []int16{1, 2},
+				Data:    []int16{1, 2},
 			},
 			103: Int32FieldData{
 				NumRows: 2,
-				data:    []int32{1, 2},
+				Data:    []int32{1, 2},
 			},
 			104: Int64FieldData{
 				NumRows: 2,
-				data:    []int64{1, 2},
+				Data:    []int64{1, 2},
 			},
 			105: FloatFieldData{
 				NumRows: 2,
-				data:    []float32{1, 2},
+				Data:    []float32{1, 2},
 			},
 			106: DoubleFieldData{
 				NumRows: 2,
-				data:    []float64{1, 2},
+				Data:    []float64{1, 2},
 			},
 			107: StringFieldData{
 				NumRows: 2,
-				data:    []string{"1", "2"},
+				Data:    []string{"1", "2"},
 			},
 			108: BinaryVectorFieldData{
 				NumRows: 8,
-				data:    []byte{0, 255, 0, 1, 0, 1, 0, 1},
-				dim:     8,
+				Data:    []byte{0, 255, 0, 1, 0, 1, 0, 1},
+				Dim:     8,
 			},
 			109: FloatVectorFieldData{
 				NumRows: 1,
-				data:    []float32{0, 1, 2, 3, 4, 5, 6, 7},
-				dim:     8,
+				Data:    []float32{0, 1, 2, 3, 4, 5, 6, 7},
+				Dim:     8,
 			},
 		},
 	}
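The two vector fixtures encode their row counts differently, which is easy to trip over: a binary vector packs Dim bits per row, so eight bytes at Dim 8 are eight rows, while a float vector stores Dim float32 values per row, so eight values at Dim 8 are one row. A quick check of the arithmetic implied by fields 108 and 109 above:

    // sketch: row-count arithmetic for the fixtures above
    numBinaryRows := len([]byte{0, 255, 0, 1, 0, 1, 0, 1}) * 8 / 8 // bits / Dim = 8 rows
    numFloatRows := len([]float32{0, 1, 2, 3, 4, 5, 6, 7}) / 8     // values / Dim = 1 row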
@@ -3,19 +3,27 @@ package writenode
 import (
 	"context"
 	"encoding/binary"
 	"fmt"
 	"math"
 	"strconv"
 	"testing"
 	"time"

 	"github.com/golang/protobuf/proto"
 	"github.com/stretchr/testify/assert"
 	"go.etcd.io/etcd/clientv3"

 	etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
 	internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
 )

 // NOTE: start pulsar before test
 func TestDataSyncService_Start(t *testing.T) {
 	newMeta()
 	const ctxTimeInMillisecond = 200
 	const closeWithDeadline = true
 	var ctx context.Context
@@ -35,56 +43,104 @@ func TestDataSyncService_Start(t *testing.T) {
 	assert.Nil(t, err)

 	// test data generate
-	const DIM = 16
-	const N = 10
-
-	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+	// GOOSE TODO orgnize
+	const DIM = 2
+	const N = 1
 	var rawData []byte
-	for _, ele := range vec {
+
+	// Float vector
+	var fvector = [DIM]float32{1, 2}
+	for _, ele := range fvector {
 		buf := make([]byte, 4)
 		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
 		rawData = append(rawData, buf...)
 	}
-	bs := make([]byte, 4)
-	binary.LittleEndian.PutUint32(bs, 1)
-	rawData = append(rawData, bs...)
-	var records []*commonpb.Blob
-	for i := 0; i < N; i++ {
-		blob := &commonpb.Blob{
-			Value: rawData,
-		}
-		records = append(records, blob)
+
+	// Binary vector
+	var bvector = [2]byte{0, 255}
+	for _, ele := range bvector {
+		bs := make([]byte, 4)
+		binary.LittleEndian.PutUint32(bs, uint32(ele))
+		rawData = append(rawData, bs...)
 	}
+
+	// Bool
+	bb := make([]byte, 4)
+	var fieldBool = false
+	var fieldBoolInt uint32
+	if fieldBool {
+		fieldBoolInt = 1
+	} else {
+		fieldBoolInt = 0
+	}
+	binary.LittleEndian.PutUint32(bb, fieldBoolInt)
+	rawData = append(rawData, bb...)
+
+	// int8
+	var dataInt8 int8 = 100
+	bint8 := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bint8, uint32(dataInt8))
+	rawData = append(rawData, bint8...)
+
+	// int16
+	var dataInt16 int16 = 200
+	bint16 := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bint16, uint32(dataInt16))
+	rawData = append(rawData, bint16...)
+
+	// int32
+	var dataInt32 int32 = 300
+	bint32 := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bint32, uint32(dataInt32))
+	rawData = append(rawData, bint32...)
+
+	// int64
+	var dataInt64 int64 = 300
+	bint64 := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bint64, uint32(dataInt64))
+	rawData = append(rawData, bint64...)
+
+	// float32
+	var datafloat float32 = 1.1
+	bfloat32 := make([]byte, 4)
+	binary.LittleEndian.PutUint32(bfloat32, math.Float32bits(datafloat))
+	rawData = append(rawData, bfloat32...)
+
+	// float64
+	var datafloat64 float64 = 2.2
+	bfloat64 := make([]byte, 8)
+	binary.LittleEndian.PutUint64(bfloat64, math.Float64bits(datafloat64))
+	rawData = append(rawData, bfloat64...)

 	timeRange := TimeRange{
 		timestampMin: 0,
 		timestampMax: math.MaxUint64,
 	}

 	// messages generate
-	const MSGLENGTH = 10
+	const MSGLENGTH = 1
 	insertMessages := make([]msgstream.TsMsg, 0)
 	for i := 0; i < MSGLENGTH; i++ {
 		var msg msgstream.TsMsg = &msgstream.InsertMsg{
 			BaseMsg: msgstream.BaseMsg{
 				HashValues: []uint32{
-					uint32(i), uint32(i),
+					uint32(i),
 				},
 			},
 			InsertRequest: internalPb.InsertRequest{
 				MsgType:        internalPb.MsgType_kInsert,
 				ReqID:          UniqueID(0),
-				CollectionName: "collection0",
+				CollectionName: "coll1",
 				PartitionTag:   "default",
-				SegmentID:      UniqueID(0),
+				SegmentID:      UniqueID(1),
 				ChannelID:      UniqueID(0),
 				ProxyID:        UniqueID(0),
-				Timestamps:     []Timestamp{Timestamp(i + 1000), Timestamp(i + 1000)},
-				RowIDs:         []UniqueID{UniqueID(i), UniqueID(i)},
+				Timestamps:     []Timestamp{Timestamp(i + 1000)},
+				RowIDs:         []UniqueID{UniqueID(i)},

 				RowData: []*commonpb.Blob{
 					{Value: rawData},
 					{Value: rawData},
 				},
 			},
 		}
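Adding up the slots written above: 2x4 bytes for the float vector, 2x4 for the padded binary-vector bytes, 4 for the bool, 4 each for int8/int16/int32/int64 (note the int64 is squeezed through a uint32 slot), 4 for the float32, and 8 for the float64, so one row blob is 48 bytes. A quick sanity check, assuming the same fixture values:

    // sketch: expected length of one row blob built by the test above
    expected := 2*4 + 2*4 + 4 + 4*4 + 4 + 8 // = 48 bytes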
@@ -149,3 +205,152 @@ func TestDataSyncService_Start(t *testing.T) {

 	<-ctx.Done()
 }

 func newMeta() {
 	ETCDAddr := Params.EtcdAddress
 	MetaRootPath := Params.MetaRootPath

 	cli, _ := clientv3.New(clientv3.Config{
 		Endpoints:   []string{ETCDAddr},
 		DialTimeout: 5 * time.Second,
 	})
 	kvClient := etcdkv.NewEtcdKV(cli, MetaRootPath)
 	defer kvClient.Close()

 	sch := schemapb.CollectionSchema{
 		Name:        "col1",
 		Description: "test collection",
 		AutoID:      false,
 		Fields: []*schemapb.FieldSchema{
 			{
 				FieldID:     100,
 				Name:        "col1_f1",
 				Description: "test collection filed 1",
 				DataType:    schemapb.DataType_VECTOR_FLOAT,
 				TypeParams: []*commonpb.KeyValuePair{
 					{
 						Key:   "dim",
 						Value: "2",
 					},
 					{
 						Key:   "col1_f1_tk2",
 						Value: "col1_f1_tv2",
 					},
 				},
 				IndexParams: []*commonpb.KeyValuePair{
 					{
 						Key:   "col1_f1_ik1",
 						Value: "col1_f1_iv1",
 					},
 					{
 						Key:   "col1_f1_ik2",
 						Value: "col1_f1_iv2",
 					},
 				},
 			},
 			{
 				FieldID:     101,
 				Name:        "col1_f2",
 				Description: "test collection filed 2",
 				DataType:    schemapb.DataType_VECTOR_BINARY,
 				TypeParams: []*commonpb.KeyValuePair{
 					{
 						Key:   "dim",
 						Value: "8",
 					},
 					{
 						Key:   "col1_f2_tk2",
 						Value: "col1_f2_tv2",
 					},
 				},
 				IndexParams: []*commonpb.KeyValuePair{
 					{
 						Key:   "col1_f2_ik1",
 						Value: "col1_f2_iv1",
 					},
 					{
 						Key:   "col1_f2_ik2",
 						Value: "col1_f2_iv2",
 					},
 				},
 			},
 			{
 				FieldID:     102,
 				Name:        "col1_f3",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_BOOL,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 			{
 				FieldID:     103,
 				Name:        "col1_f4",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_INT8,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 			{
 				FieldID:     104,
 				Name:        "col1_f5",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_INT16,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 			{
 				FieldID:     105,
 				Name:        "col1_f6",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_INT32,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 			{
 				FieldID:     106,
 				Name:        "col1_f7",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_INT64,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 			{
 				FieldID:     107,
 				Name:        "col1_f8",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_FLOAT,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 			{
 				FieldID:     108,
 				Name:        "col1_f9",
 				Description: "test collection filed 3",
 				DataType:    schemapb.DataType_DOUBLE,
 				TypeParams:  []*commonpb.KeyValuePair{},
 				IndexParams: []*commonpb.KeyValuePair{},
 			},
 		},
 	}

 	collection := etcdpb.CollectionMeta{
 		ID:            UniqueID(1),
 		Schema:        &sch,
 		CreateTime:    Timestamp(1),
 		SegmentIDs:    make([]UniqueID, 0),
 		PartitionTags: make([]string, 0),
 	}

 	collBytes := proto.MarshalTextString(&collection)
 	kvClient.Save("/collection/"+strconv.FormatInt(collection.ID, 10), collBytes)
 	value, _ := kvClient.Load("/collection/1")
 	fmt.Println("========value: ", value)

 	segSch := etcdpb.SegmentMeta{
 		SegmentID:    UniqueID(1),
 		CollectionID: UniqueID(1),
 	}
 	segBytes := proto.MarshalTextString(&segSch)
 	kvClient.Save("/segment/"+strconv.FormatInt(segSch.SegmentID, 10), segBytes)
 }
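newMeta writes the metadata with proto.MarshalTextString, which is what lets the insert-buffer node further down read it back with proto.UnmarshalText. A minimal sketch of the read side, assuming the same "/collection/<id>" key layout:

    // sketch: read back the CollectionMeta saved above (text-proto round trip)
    func loadCollMeta(kvClient *etcdkv.EtcdKV) (*etcdpb.CollectionMeta, error) {
        value, err := kvClient.Load("/collection/1")
        if err != nil {
            return nil, err
        }
        collMeta := etcdpb.CollectionMeta{}
        if err := proto.UnmarshalText(value, &collMeta); err != nil {
            return nil, err
        }
        return &collMeta, nil
    }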
@@ -22,7 +22,7 @@ func (ddNode *ddNode) Name() string {
 }

 func (ddNode *ddNode) Operate(in []*Msg) []*Msg {
-	//fmt.Println("Do filterDmNode operation")
+	//fmt.Println("Do filterDdNode operation")

 	if len(in) != 1 {
 		log.Println("Invalid operate message input in ddNode, input length = ", len(in))
@@ -40,6 +40,7 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {

 	var iMsg = insertMsg{
 		insertMessages: make([]*msgstream.InsertMsg, 0),
+		flushMessages:  make([]*msgstream.FlushMsg, 0),
 		timeRange: TimeRange{
 			timestampMin: msgStreamMsg.TimestampMin(),
 			timestampMax: msgStreamMsg.TimestampMax(),

@@ -53,7 +54,7 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
 			iMsg.insertMessages = append(iMsg.insertMessages, resMsg)
 		}
 	case internalPb.MsgType_kFlush:
-		iMsg.insertMessages = append(iMsg.insertMessages, msg.(*msgstream.InsertMsg))
+		iMsg.flushMessages = append(iMsg.flushMessages, msg.(*msgstream.FlushMsg))
 	// case internalPb.MsgType_kDelete:
 	//	dmMsg.deleteMessages = append(dmMsg.deleteMessages, (*msg).(*msgstream.DeleteTask))
 	default:
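This fix matters: a kFlush message was previously type-asserted to *msgstream.InsertMsg, which would panic at runtime since the concrete type is *msgstream.FlushMsg; flush messages now travel in their own slice. The shape of insertMsg implied by these two hunks (a sketch; the full definition lives elsewhere in the package):

    // sketch: fields of insertMsg implied by the constructor literal above
    insertMsg struct {
        insertMessages []*msgstream.InsertMsg
        flushMessages  []*msgstream.FlushMsg
        timeRange      TimeRange
    }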
@@ -1,31 +1,72 @@
 package writenode

 import (
 	"encoding/binary"
 	"log"
 	"math"
 	"path"
 	"strconv"
 	"time"

 	"github.com/golang/protobuf/proto"
 	etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
 	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
 	"github.com/zilliztech/milvus-distributed/internal/storage"
 	"go.etcd.io/etcd/clientv3"
 )

 const (
 	CollectionPrefix = "/collection/"
 	SegmentPrefix    = "/segment/"
 )

 type (
 	InsertData = storage.InsertData
 	Blob       = storage.Blob

 	insertBufferNode struct {
 		BaseNode
-		binLogs map[SegmentID][]*storage.Blob // Binary logs of a segment.
-		buffer  *insertBuffer
-	}
-
-	insertBufferData struct {
-		logIdx      int // TODO What's it for?
-		partitionID UniqueID
-		segmentID   UniqueID
-		data        *storage.InsertData
+		kvClient     *etcdkv.EtcdKV
+		insertBuffer *insertBuffer
 	}

 	insertBuffer struct {
-		buffer  []*insertBufferData
-		maxSize int // TODO set from write_node.yaml
+		insertData map[UniqueID]*InsertData // SegmentID to InsertData
+		maxSize    int                      // GOOSE TODO set from write_node.yaml
 	}
 )

 func (ib *insertBuffer) size(segmentID UniqueID) int {
 	if ib.insertData == nil || len(ib.insertData) <= 0 {
 		return 0
 	}
 	idata, ok := ib.insertData[segmentID]
 	if !ok {
 		return 0
 	}

 	maxSize := 0
 	for _, data := range idata.Data {
 		fdata, ok := data.(storage.FloatVectorFieldData)
 		if ok && len(fdata.Data) > maxSize {
 			maxSize = len(fdata.Data)
 		}

 		bdata, ok := data.(storage.BinaryVectorFieldData)
 		if ok && len(bdata.Data) > maxSize {
 			maxSize = len(bdata.Data)
 		}
 	}
 	return maxSize
 }

 func (ib *insertBuffer) full(segmentID UniqueID) bool {
 	// GOOSE TODO
 	return ib.size(segmentID) >= ib.maxSize
 }

 func (ibNode *insertBufferNode) Name() string {
 	return "ibNode"
 }
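Note what size actually measures: not rows and not bytes, but the longest vector payload (element count) buffered for a segment, taking the maximum over the float- and binary-vector fields. With the maxSize of 10 set in newInsertBufferNode below, a segment holding the two-float test row is nowhere near full. A sketch of the semantics:

    // sketch: fullness semantics of the buffer above
    ib := &insertBuffer{
        insertData: map[UniqueID]*InsertData{},
        maxSize:    10,
    }
    ib.insertData[1] = &InsertData{Data: map[UniqueID]storage.FieldData{
        100: storage.FloatVectorFieldData{NumRows: 1, Data: []float32{1, 2}, Dim: 2},
    }}
    _ = ib.full(1) // false: size(1) == 2, which is < 10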
@@ -38,37 +79,212 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
 		// TODO: add error handling
 	}

-	_, ok := (*in[0]).(*insertMsg)
+	iMsg, ok := (*in[0]).(*insertMsg)
 	if !ok {
 		log.Println("type assertion failed for insertMsg")
 		// TODO: add error handling
 	}
-	for _, task := range iMsg.insertMessages {
-		if len(task.RowIDs) != len(task.Timestamps) || len(task.RowIDs) != len(task.RowData) {
-			log.Println("Error, misaligned messages detected")
-			continue
-		}

 	// iMsg is insertMsg
 	// 1. iMsg -> insertBufferData -> insertBuffer
 	// 2. Send hardTimeTick msg
 	// 3. if insertBuffer full
 	//   3.1 insertBuffer -> binLogs
 	//   3.2 binLogs -> minIO/S3
 	// iMsg is Flush() msg from master
 	// 1. insertBuffer(not empty) -> binLogs -> minIO/S3
 	// Return

 	// iMsg is insertMsg
 	// 1. iMsg -> binLogs -> buffer
 	for _, msg := range iMsg.insertMessages {
 		currentSegID := msg.GetSegmentID()

 		// log.Println("=========== insertMsg length:", len(iMsg.insertMessages))
 		// for _, task := range iMsg.insertMessages {
 		// 	if len(task.RowIDs) != len(task.Timestamps) || len(task.RowIDs) != len(task.RowData) {
 		// 		log.Println("Error, misaligned messages detected")
 		// 		continue
 		// 	}
 		// 	log.Println("Timestamp: ", task.Timestamps[0])
 		// 	log.Printf("t(%d) : %v ", task.Timestamps[0], task.RowData[0])
 		// }

 		idata, ok := ibNode.insertBuffer.insertData[currentSegID]
 		if !ok {
 			idata = &InsertData{
 				Data: make(map[UniqueID]storage.FieldData),
 			}
 		}

 		// TODO
 		idata.Data[1] = msg.BeginTimestamp

 		// 1.1 Get CollectionMeta from etcd
 		// GOOSE TODO get meta from metaTable
 		segMeta := etcdpb.SegmentMeta{}

 		key := path.Join(SegmentPrefix, strconv.FormatInt(currentSegID, 10))
 		value, _ := ibNode.kvClient.Load(key)
 		err := proto.UnmarshalText(value, &segMeta)
 		if err != nil {
 			log.Println("Load segMeta error")
 			// TODO: add error handling
 		}

 		collMeta := etcdpb.CollectionMeta{}
 		key = path.Join(CollectionPrefix, strconv.FormatInt(segMeta.GetCollectionID(), 10))
 		value, _ = ibNode.kvClient.Load(key)
 		err = proto.UnmarshalText(value, &collMeta)
 		if err != nil {
 			log.Println("Load collMeta error")
 			// TODO: add error handling
 		}

 		// 1.2 Get Fields
 		var pos = 0 // Record position of blob
 		for _, field := range collMeta.Schema.Fields {
 			switch field.DataType {
 			case schemapb.DataType_VECTOR_FLOAT:
 				var dim int
 				for _, t := range field.TypeParams {
 					if t.Key == "dim" {
 						dim, err = strconv.Atoi(t.Value)
 						if err != nil {
 							log.Println("strconv wrong")
 						}
 						break
 					}
 				}
 				if dim <= 0 {
 					log.Println("invalid dim")
 					// TODO: add error handling
 				}

 				data := make([]float32, 0)
 				for _, blob := range msg.RowData {
 					for j := pos; j < dim; j++ {
 						v := binary.LittleEndian.Uint32(blob.GetValue()[j*4:])
 						data = append(data, math.Float32frombits(v))
 						pos++
 					}
 				}
 				idata.Data[field.FieldID] = storage.FloatVectorFieldData{
 					NumRows: len(msg.RowIDs),
 					Data:    data,
 					Dim:     dim,
 				}

 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_VECTOR_BINARY:
 				// GOOSE TODO
 				var dim int
 				for _, t := range field.TypeParams {
 					if t.Key == "dim" {
 						dim, err = strconv.Atoi(t.Value)
 						if err != nil {
 							log.Println("strconv wrong")
 						}
 						break
 					}
 				}
 				if dim <= 0 {
 					log.Println("invalid dim")
 					// TODO: add error handling
 				}

 				data := make([]byte, 0)
 				for _, blob := range msg.RowData {
 					for d := 0; d < dim/4; d++ {
 						v := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 						data = append(data, byte(v))
 						pos++
 					}
 				}
 				idata.Data[field.FieldID] = storage.BinaryVectorFieldData{
 					NumRows: len(data) * 8 / dim,
 					Data:    data,
 					Dim:     dim,
 				}
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_BOOL:
 				data := make([]bool, 0)
 				for _, blob := range msg.RowData {
 					boolInt := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 					if boolInt == 1 {
 						data = append(data, true)
 					} else {
 						data = append(data, false)
 					}
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_INT8:
 				data := make([]int8, 0)
 				for _, blob := range msg.RowData {
 					v := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 					data = append(data, int8(v))
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_INT16:
 				data := make([]int16, 0)
 				for _, blob := range msg.RowData {
 					v := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 					data = append(data, int16(v))
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_INT32:
 				data := make([]int32, 0)
 				for _, blob := range msg.RowData {
 					v := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 					data = append(data, int32(v))
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_INT64:
 				data := make([]int64, 0)
 				for _, blob := range msg.RowData {
 					v := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 					data = append(data, int64(v))
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_FLOAT:
 				data := make([]float32, 0)
 				for _, blob := range msg.RowData {
 					v := binary.LittleEndian.Uint32(blob.GetValue()[pos*4:])
 					data = append(data, math.Float32frombits(v))
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			case schemapb.DataType_DOUBLE:
 				// GOOSE TODO pos
 				data := make([]float64, 0)
 				for _, blob := range msg.RowData {
 					v := binary.LittleEndian.Uint64(blob.GetValue()[pos*4:])
 					data = append(data, math.Float64frombits(v))
 					pos++
 				}
 				idata.Data[field.FieldID] = data
 				log.Println("aaaaaaaa", idata)
 			}
 		}

 		// 1.3 store in buffer
 		ibNode.insertBuffer.insertData[currentSegID] = idata
 		// 1.4 Send hardTimeTick msg

 		// 1.5 if full
 		// 1.5.1 generate binlogs
 		// GOOSE TODO partitionTag -> partitionID
 		// 1.5.2 binLogs -> minIO/S3
 		if ibNode.insertBuffer.full(currentSegID) {
 			continue
 		}
 	}

 	// iMsg is Flush() msg from master
 	// 1. insertBuffer(not empty) -> binLogs -> minIO/S3
 	// Return

 	}
 	return nil
 }

 func newInsertBufferNode() *insertBufferNode {

 	maxQueueLength := Params.FlowGraphMaxQueueLength
 	maxParallelism := Params.FlowGraphMaxParallelism
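The decode loops above mirror the writer in the test: every scalar occupies a 4-byte little-endian slot indexed by pos (the float64 case reads 8 bytes but still steps pos by one slot, which the "GOOSE TODO pos" comment flags). Note too that the scalar cases store plain slices ([]bool, []int8, and so on) into idata.Data, while the codec's Serialize asserts the concrete FieldData structs (BoolFieldData and friends), so those entries would not serialize as-is. A standalone sketch of one slot's round trip:

    // sketch: one 4-byte slot as written by the test and read back here
    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        blob := make([]byte, 4)
        binary.LittleEndian.PutUint32(blob, uint32(300)) // writer side (test fixture)
        v := int32(binary.LittleEndian.Uint32(blob))     // reader side (insertBufferNode)
        fmt.Println(v) // 300
    }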
@@ -76,16 +292,26 @@ func newInsertBufferNode() *insertBufferNode {
 	baseNode.SetMaxQueueLength(maxQueueLength)
 	baseNode.SetMaxParallelism(maxParallelism)

-	// TODO read from yaml
+	// GOOSE TODO maxSize read from yaml
 	maxSize := 10
 	iBuffer := &insertBuffer{
-		buffer:  make([]*insertBufferData, maxSize),
-		maxSize: maxSize,
+		insertData: make(map[UniqueID]*InsertData),
+		maxSize:    maxSize,
 	}

+	// EtcdKV
+	ETCDAddr := Params.EtcdAddress
+	MetaRootPath := Params.MetaRootPath
+	log.Println("metaRootPath: ", MetaRootPath)
+	cli, _ := clientv3.New(clientv3.Config{
+		Endpoints:   []string{ETCDAddr},
+		DialTimeout: 5 * time.Second,
+	})
+	kvClient := etcdkv.NewEtcdKV(cli, MetaRootPath)
+
 	return &insertBufferNode{
-		BaseNode: baseNode,
-		binLogs:  make(map[SegmentID][]*storage.Blob),
-		buffer:   iBuffer,
+		BaseNode:     baseNode,
+		kvClient:     kvClient,
+		insertBuffer: iBuffer,
 	}
 }
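One hedged aside: clientv3.New's error is discarded here (as it is in newMeta), so a bad EtcdAddress only surfaces later as failed meta loads. A checked variant with the same config would look like:

    // sketch: same etcd client construction, with the error handled
    cli, err := clientv3.New(clientv3.Config{
        Endpoints:   []string{Params.EtcdAddress},
        DialTimeout: 5 * time.Second,
    })
    if err != nil {
        log.Fatal(err) // the commit drops this error instead
    }
    kvClient := etcdkv.NewEtcdKV(cli, Params.MetaRootPath)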
@@ -8,7 +8,6 @@ import (
 type (
 	Msg          = flowgraph.Msg
 	MsgStreamMsg = flowgraph.MsgStreamMsg
-	SegmentID    = UniqueID
 )

 type (
@@ -40,28 +40,19 @@ func NewWriteNode(ctx context.Context, writeNodeID uint64) (*WriteNode, error) {

 func (node *WriteNode) Start() {
 	node.dataSyncService = newDataSyncService(node.ctx)
 	// node.searchService = newSearchService(node.ctx)
 	// node.metaService = newMetaService(node.ctx)
 	// node.statsService = newStatsService(node.ctx)

 	go node.dataSyncService.start()
 	// go node.searchService.start()
 	// go node.metaService.start()
 	// node.statsService.start()
 }

 func (node *WriteNode) Close() {
 	<-node.ctx.Done()
 	// free collectionReplica
 	// (*node.replica).freeAll()

 	// close services
 	if node.dataSyncService != nil {
 		(*node.dataSyncService).close()
 	}
 	// if node.searchService != nil {
 	// 	(*node.searchService).close()
 	// }
 	// if node.statsService != nil {
 	// 	(*node.statsService).close()
 	// }
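The lifecycle implied by this hunk: Start spawns the data-sync service in a goroutine and returns, while Close blocks on the node context before shutting the services down. A sketch using the NewWriteNode signature shown in the hunk header:

    // sketch: WriteNode lifecycle as wired above
    func runWriteNode(ctx context.Context) error {
        node, err := NewWriteNode(ctx, 0)
        if err != nil {
            return err
        }
        node.Start() // launches dataSyncService in a goroutine, returns immediately
        node.Close() // blocks until ctx is done, then closes dataSyncService
        return nil
    }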
@@ -705,7 +705,8 @@ class TestSearchBase:
         # TODO:
         # assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_jaccard_flat_index")
     def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
         '''
         target: search binary_collection, and check the result: distance

@@ -739,7 +740,8 @@ class TestSearchBase:
         with pytest.raises(Exception) as e:
             res = connect.search(binary_collection, query)

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_hamming_flat_index")
     @pytest.mark.level(2)
     def test_search_distance_hamming_flat_index(self, connect, binary_collection):
         '''

@@ -756,7 +758,8 @@ class TestSearchBase:
         res = connect.search(binary_collection, query)
         assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_substructure_flat_index")
     @pytest.mark.level(2)
     def test_search_distance_substructure_flat_index(self, connect, binary_collection):
         '''

@@ -774,7 +777,8 @@ class TestSearchBase:
         res = connect.search(binary_collection, query)
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_substructure_flat_index_B")
     @pytest.mark.level(2)
     def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
         '''

@@ -793,7 +797,8 @@ class TestSearchBase:
         assert res[1][0].distance <= epsilon
         assert res[1][0].id == ids[1]

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_superstructure_flat_index")
     @pytest.mark.level(2)
     def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
         '''

@@ -811,7 +816,8 @@ class TestSearchBase:
         res = connect.search(binary_collection, query)
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_superstructure_flat_index_B")
     @pytest.mark.level(2)
     def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
         '''

@@ -832,7 +838,8 @@ class TestSearchBase:
         assert res[1][0].id in ids
         assert res[1][0].distance <= epsilon

     # PASS
     # DOG: TODO BINARY
     @pytest.mark.skip("search_distance_tanimoto_flat_index")
     @pytest.mark.level(2)
     def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
         '''

@@ -970,7 +977,8 @@ class TestSearchDSL(object):
     ******************************************************************
     """

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_no_must")
     def test_query_no_must(self, connect, collection):
         '''
         method: build query without must expr

@@ -981,7 +989,8 @@ class TestSearchDSL(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_no_vector_term_only")
     def test_query_no_vector_term_only(self, connect, collection):
         '''
         method: build query without vector only term

@@ -1016,7 +1025,8 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == default_top_k

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_wrong_format")
     def test_query_wrong_format(self, connect, collection):
         '''
         method: build query without must expr, with wrong expr name

@@ -1158,7 +1168,8 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO TRC
     @pytest.mark.skip("query_complex_dsl")
     def test_query_complex_dsl(self, connect, collection):
         '''
         method: query with complicated dsl

@@ -1180,7 +1191,9 @@ class TestSearchDSL(object):
     ******************************************************************
     """

     # PASS
     # DOG: TODO INVALID DSL
     # TODO
     @pytest.mark.skip("query_term_key_error")
     @pytest.mark.level(2)
     def test_query_term_key_error(self, connect, collection):
         '''

@@ -1200,7 +1213,8 @@ class TestSearchDSL(object):
     def get_invalid_term(self, request):
         return request.param

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_term_wrong_format")
     @pytest.mark.level(2)
     def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
         '''

@@ -1214,7 +1228,7 @@ class TestSearchDSL(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

-    # DOG: PLEASE IMPLEMENT connect.count_entities
+    # DOG: TODO UNKNOWN
     # TODO
     @pytest.mark.skip("query_term_field_named_term")
     @pytest.mark.level(2)

@@ -1230,8 +1244,8 @@ class TestSearchDSL(object):
         ids = connect.bulk_insert(collection_term, term_entities)
         assert len(ids) == default_nb
         connect.flush([collection_term])
-        count = connect.count_entities(collection_term) # count_entities is not impelmented
-        assert count == default_nb # removing these two lines, this test passed
+        count = connect.count_entities(collection_term)
+        assert count == default_nb
         term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
         expr = {"must": [gen_default_vector_expr(default_query),
                          term_param]}

@@ -1241,7 +1255,8 @@ class TestSearchDSL(object):
         assert len(res[0]) == default_top_k
         connect.drop_collection(collection_term)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_term_one_field_not_existed")
     @pytest.mark.level(2)
     def test_query_term_one_field_not_existed(self, connect, collection):
         '''

@@ -1263,6 +1278,7 @@ class TestSearchDSL(object):
     """

     # PASS
     # TODO
     def test_query_range_key_error(self, connect, collection):
         '''
         method: build query with range key error

@@ -1282,6 +1298,7 @@ class TestSearchDSL(object):
         return request.param

     # PASS
     # TODO
     @pytest.mark.level(2)
     def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
         '''

@@ -1349,7 +1366,8 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == default_top_k

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_range_one_field_not_existed")
     def test_query_range_one_field_not_existed(self, connect, collection):
         '''
         method: build query with two fields ranges, one of fields not existed

@@ -1369,7 +1387,10 @@ class TestSearchDSL(object):
     ************************************************************************
     """

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_multi_term_has_common")
     @pytest.mark.level(2)
     def test_query_multi_term_has_common(self, connect, collection):
         '''
         method: build query with multi term with same field, and values has common

@@ -1384,7 +1405,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == default_top_k

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_multi_term_no_common")
     @pytest.mark.level(2)
     def test_query_multi_term_no_common(self, connect, collection):
         '''

@@ -1400,7 +1423,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_multi_term_different_fields")
     def test_query_multi_term_different_fields(self, connect, collection):
         '''
         method: build query with multi range with same field, and ranges no common

@@ -1416,7 +1441,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_single_term_multi_fields")
     @pytest.mark.level(2)
     def test_query_single_term_multi_fields(self, connect, collection):
         '''

@@ -1432,7 +1459,9 @@ class TestSearchDSL(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_multi_range_has_common")
     @pytest.mark.level(2)
     def test_query_multi_range_has_common(self, connect, collection):
         '''

@@ -1448,7 +1477,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == default_top_k

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_multi_range_no_common")
     @pytest.mark.level(2)
     def test_query_multi_range_no_common(self, connect, collection):
         '''

@@ -1464,7 +1495,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_multi_range_different_fields")
     @pytest.mark.level(2)
     def test_query_multi_range_different_fields(self, connect, collection):
         '''

@@ -1480,7 +1513,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == 0

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_single_range_multi_fields")
     @pytest.mark.level(2)
     def test_query_single_range_multi_fields(self, connect, collection):
         '''

@@ -1502,7 +1537,9 @@ class TestSearchDSL(object):
     ******************************************************************
     """

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_single_term_range_has_common")
     @pytest.mark.level(2)
     def test_query_single_term_range_has_common(self, connect, collection):
         '''

@@ -1518,7 +1555,9 @@ class TestSearchDSL(object):
         assert len(res) == nq
         assert len(res[0]) == default_top_k

     # PASS
     # DOG: TODO TRC
     # TODO
     @pytest.mark.skip("query_single_term_range_no_common")
     def test_query_single_term_range_no_common(self, connect, collection):
         '''
         method: build query with single term single range

@@ -1540,6 +1579,7 @@ class TestSearchDSL(object):
     """

     # PASS
     # TODO
     def test_query_multi_vectors_same_field(self, connect, collection):
         '''
         method: build query with two vectors same field

@@ -1576,7 +1616,8 @@ class TestSearchDSLBools(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_should_only_term")
     def test_query_should_only_term(self, connect, collection):
         '''
         method: build query without must, with should.term instead

@@ -1587,7 +1628,8 @@ class TestSearchDSLBools(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_should_only_vector")
     def test_query_should_only_vector(self, connect, collection):
         '''
         method: build query without must, with should.vector instead

@@ -1598,7 +1640,8 @@ class TestSearchDSLBools(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_must_not_only_term")
     def test_query_must_not_only_term(self, connect, collection):
         '''
         method: build query without must, with must_not.term instead

@@ -1609,7 +1652,8 @@ class TestSearchDSLBools(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_must_not_vector")
     def test_query_must_not_vector(self, connect, collection):
         '''
         method: build query without must, with must_not.vector instead

@@ -1620,7 +1664,8 @@ class TestSearchDSLBools(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     # PASS
     # DOG: TODO INVALID DSL
     @pytest.mark.skip("query_must_should")
     def test_query_must_should(self, connect, collection):
         '''
         method: build query must, and with should.term