Add insertion API, fix unit tests and data types

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/4973/head^2
bigsheeper 2020-09-02 17:18:49 +08:00 committed by yefu.chen
parent 36c8362b41
commit b817fa5aed
6 changed files with 87 additions and 117 deletions

View File

@ -85,29 +85,6 @@ TEST(CApiTest, DeleteTest) {
auto partition = NewPartition(collection, partition_name);
auto segment = NewSegment(partition, 0);
std::vector<char> raw_data;
std::vector<uint64_t> timestamps;
std::vector<uint64_t> uids;
int N = 10000;
std::default_random_engine e(67);
for(int i = 0; i < N; ++i) {
uids.push_back(100000 + i);
timestamps.push_back(0);
// append vec
float vec[16];
for(auto &x: vec) {
x = e() % 2000 * 0.001 - 1.0;
}
raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
int age = e() % 100;
raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
}
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto ins_res = Insert(segment, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res == 0);
unsigned long delete_primary_keys[] = {100000, 100001, 100002};
unsigned long delete_timestamps[] = {0, 0, 0};

View File

@ -24,9 +24,6 @@ import (
"github.com/czs007/suvlim/errors"
"github.com/pingcap/log"
//"github.com/czs007/suvlim/util/tsoutil"
//"github.com/tikv/pd/server/cluster"
//"github.com/tikv/pd/server/core"
//"github.com/tikv/pd/server/versioninfo"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

View File

@ -68,7 +68,7 @@ type VectorRowRecord struct {
}
type VectorRecord struct {
Records []*VectorRowRecord
Records *VectorRowRecord
}
type FieldValue struct {
@ -123,7 +123,7 @@ type Message interface {
type InsertMsg struct {
CollectionName string
Fields []*FieldValue
EntityId int64
EntityId uint64
PartitionTag string
SegmentId uint64
Timestamp uint64
@ -133,7 +133,7 @@ type InsertMsg struct {
type DeleteMsg struct {
CollectionName string
EntityId int64
EntityId uint64
Timestamp uint64
ClientId int64
MsgType OpType

View File

@ -102,7 +102,7 @@ func (node *QueryNode) StartMessageClient() {
go node.messageClient.ReceiveMessage()
}
func (node *QueryNode) GetSegmentByEntityId(entityId int64) *Segment {
func (node *QueryNode) GetSegmentByEntityId(entityId uint64) *Segment {
// TODO: get id2segment info from pulsar
return nil
}
@ -185,7 +185,7 @@ func (node *QueryNode) Insert(insertMessages []*schema.InsertMsg, wg *sync.WaitG
var clientId = insertMessages[0].ClientId
// TODO: prevent Memory copy
var entityIds []int64
var entityIds []uint64
var timestamps []uint64
var vectorRecords [][]*schema.FieldValue
@ -224,7 +224,8 @@ func (node *QueryNode) Insert(insertMessages []*schema.InsertMsg, wg *sync.WaitG
return schema.Status{}
}
var result = SegmentInsert(targetSegment, &entityIds, &timestamps, vectorRecords)
// TODO: check error
var result, _ = SegmentInsert(targetSegment, &entityIds, &timestamps, vectorRecords)
wg.Done()
return publishResult(&result, clientId)
@ -232,11 +233,10 @@ func (node *QueryNode) Insert(insertMessages []*schema.InsertMsg, wg *sync.WaitG
func (node *QueryNode) Delete(deleteMessages []*schema.DeleteMsg, wg *sync.WaitGroup) schema.Status {
var timeSync = node.GetTimeSync()
var collectionName = deleteMessages[0].CollectionName
var clientId = deleteMessages[0].ClientId
// TODO: prevent Memory copy
var entityIds []int64
var entityIds []uint64
var timestamps []uint64
for i, msg := range node.buffer.DeleteBuffer {
@ -273,7 +273,8 @@ func (node *QueryNode) Delete(deleteMessages []*schema.DeleteMsg, wg *sync.WaitG
// TODO: does all entities from a common batch are in the same segment?
var targetSegment = node.GetSegmentByEntityId(entityIds[0])
var result = SegmentDelete(targetSegment, &entityIds, &timestamps)
// TODO: check error
var result, _ = SegmentDelete(targetSegment, &entityIds, &timestamps)
wg.Done()
return publishResult(&result, clientId)
@ -323,7 +324,8 @@ func (node *QueryNode) Search(searchMessages []*schema.SearchMsg, wg *sync.WaitG
return schema.Status{}
}
var result = SegmentSearch(targetSegment, queryString, &timestamps, &records)
// TODO: check error
var result, _ = SegmentSearch(targetSegment, queryString, &timestamps, &records)
wg.Done()
return publishSearchResult(result, clientId)

View File

@ -13,6 +13,7 @@ package reader
*/
import "C"
import (
"github.com/czs007/suvlim/errors"
"github.com/czs007/suvlim/pulsar/client-go/schema"
"unsafe"
)
@ -61,7 +62,7 @@ func (s *Segment) Close() {
}
////////////////////////////////////////////////////////////////////////////
func SegmentInsert(segment *Segment, entityIds *[]int64, timestamps *[]uint64, dataChunk [][]*schema.FieldValue) ResultEntityIds {
func SegmentInsert(segment *Segment, entityIds *[]uint64, timestamps *[]uint64, dataChunk [][]*schema.FieldValue) (ResultEntityIds, error) {
// TODO: remove hard code schema
// auto schema_tmp = std::make_shared<Schema>();
// schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16);
@ -78,21 +79,29 @@ func SegmentInsert(segment *Segment, entityIds *[]int64, timestamps *[]uint64, d
signed long int count);
*/
//msgCount := len(dataChunk)
//cEntityIds := (*C.ulong)(entityIds)
//
//// dataChunk to raw data
//var rawData []byte
//var i int
//for i = 0; i < msgCount; i++ {
// rawVector := dataChunk[i][0].VectorRecord.Records
// rawData = append(rawData, rawVector...)
//}
// TODO: remove hard code & fake dataChunk
const DIM = 4
const N = 3
var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4}
var rawData []int8
for i := 0; i <= N; i++ {
for _, ele := range vec {
rawData=append(rawData, int8(ele))
}
rawData=append(rawData, int8(i))
}
const sizeofPerRow = 4 + DIM * 4
return ResultEntityIds{}
var status = C.Insert(segment.SegmentPtr, (*C.ulong)(entityIds), (*C.ulong)(timestamps), unsafe.Pointer(&rawData[0]), C.int(sizeofPerRow), C.long(N))
if status != 0 {
return nil, errors.New("Insert failed, error code = " + status)
}
return ResultEntityIds{}, nil
}
func SegmentDelete(segment *Segment, entityIds *[]int64, timestamps *[]uint64) ResultEntityIds {
func SegmentDelete(segment *Segment, entityIds *[]uint64, timestamps *[]uint64) (ResultEntityIds, error) {
/*C.Delete
int
Delete(CSegmentBase c_segment,
@ -102,13 +111,16 @@ func SegmentDelete(segment *Segment, entityIds *[]int64, timestamps *[]uint64) R
*/
size := len(*entityIds)
// TODO: add query result status check
var _ = C.Delete(segment.SegmentPtr, C.long(size), (*C.ulong)(entityIds), (*C.ulong)(timestamps))
var status = C.Delete(segment.SegmentPtr, C.long(size), (*C.ulong)(entityIds), (*C.ulong)(timestamps))
return ResultEntityIds{}
if status != 0 {
return nil, errors.New("Delete failed, error code = " + status)
}
return ResultEntityIds{}, nil
}
func SegmentSearch(segment *Segment, queryString string, timestamps *[]uint64, vectorRecord *[]schema.VectorRecord) *[]SearchResult {
func SegmentSearch(segment *Segment, queryString string, timestamps *[]uint64, vectorRecord *[]schema.VectorRecord) (*[]SearchResult, error) {
/*C.Search
int
Search(CSegmentBase c_segment,
@ -126,11 +138,13 @@ func SegmentSearch(segment *Segment, queryString string, timestamps *[]uint64, v
resultIds := make([]int64, TopK)
resultDistances := make([]float32, TopK)
// TODO: add query result status check
var _ = C.Search(segment.SegmentPtr, unsafe.Pointer(nil), C.ulong(timestamp), (*C.long)(&resultIds[0]), (*C.float)(&resultDistances[0]))
var status = C.Search(segment.SegmentPtr, unsafe.Pointer(nil), C.ulong(timestamp), (*C.long)(&resultIds[0]), (*C.float)(&resultDistances[0]))
if status != 0 {
return nil, errors.New("Search failed, error code = " + status)
}
results = append(results, SearchResult{ResultIds: resultIds, ResultDistances: resultDistances})
}
return &results
return &results, nil
}

View File

@ -1,6 +1,8 @@
package reader
import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)
@ -15,36 +17,22 @@ func TestConstructorAndDestructor(t *testing.T) {
node.DeleteCollection(collection)
}
//func TestSegmentInsert(t *testing.T) {
// node := NewQueryNode(0, 0)
// var collection = node.NewCollection("collection0", "fake schema")
// var partition = collection.NewPartition("partition0")
// var segment = partition.NewSegment(0)
//
// const DIM = 4
// const N = 3
//
// var ids = [N]uint64{1, 2, 3}
// var timestamps = [N]uint64{0, 0, 0}
//
// var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4}
// var rawData []int8
//
// for i := 0; i <= N; i++ {
// for _, ele := range vec {
// rawData=append(rawData, int8(ele))
// }
// rawData=append(rawData, int8(i))
// }
//
// const sizeofPerRow = 4 + DIM * 4
// var res = Insert(segment, N, (*C.ulong)(&ids[0]), (*C.ulong)(&timestamps[0]), unsafe.Pointer(&rawData[0]), C.int(sizeofPerRow), C.long(N))
// assert.Equal()
//
// partition.DeleteSegment(segment)
// collection.DeletePartition(partition)
// node.DeleteCollection(collection)
//}
// TestSegmentInsert verifies that SegmentInsert succeeds on a freshly
// created segment. The dataChunk argument is nil because SegmentInsert
// currently builds hard-coded fake raw data internally (see the TODO in
// the segment implementation).
func TestSegmentInsert(t *testing.T) {
	node := NewQueryNode(0, 0)
	collection := node.NewCollection("collection0", "fake schema")
	partition := collection.NewPartition("partition0")
	segment := partition.NewSegment(0)

	ids := []uint64{1, 2, 3}
	timestamps := []uint64{0, 0, 0}

	_, err := SegmentInsert(segment, &ids, &timestamps, nil)
	assert.NoError(t, err)

	// Tear down in reverse order of creation.
	partition.DeleteSegment(segment)
	collection.DeletePartition(partition)
	node.DeleteCollection(collection)
}
func TestSegmentDelete(t *testing.T) {
node := NewQueryNode(0, 0)
@ -52,42 +40,34 @@ func TestSegmentDelete(t *testing.T) {
var partition = collection.NewPartition("partition0")
var segment = partition.NewSegment(0)
ids :=[] int64{1, 2, 3}
ids :=[] uint64{1, 2, 3}
timestamps :=[] uint64 {0, 0, 0}
SegmentDelete(segment, &ids, &timestamps)
var _, err = SegmentDelete(segment, &ids, &timestamps)
assert.NoError(t, err)
partition.DeleteSegment(segment)
collection.DeletePartition(partition)
node.DeleteCollection(collection)
}
//func TestSegmentSearch(t *testing.T) {
// node := NewQueryNode(0, 0)
// var collection = node.NewCollection("collection0", "fake schema")
// var partition = collection.NewPartition("partition0")
// var segment = partition.NewSegment(0)
//
// const DIM = 4
// const N = 3
//
// var ids = [N]uint64{1, 2, 3}
// var timestamps = [N]uint64{0, 0, 0}
//
// var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4}
// var rawData []int8
//
// for i := 0; i <= N; i++ {
// for _, ele := range vec {
// rawData=append(rawData, int8(ele))
// }
// rawData=append(rawData, int8(i))
// }
//
// const sizeofPerRow = 4 + DIM * 4
// SegmentSearch(segment, "fake query string", &timestamps, nil)
//
// partition.DeleteSegment(segment)
// collection.DeletePartition(partition)
// node.DeleteCollection(collection)
//}
// TestSegmentSearch inserts a few fake rows into a fresh segment and then
// runs a search over it, verifying that neither step reports an error.
// The search result is only printed, not asserted, because the segment
// currently returns hard-coded fake data.
func TestSegmentSearch(t *testing.T) {
	node := NewQueryNode(0, 0)
	collection := node.NewCollection("collection0", "fake schema")
	partition := collection.NewPartition("partition0")
	segment := partition.NewSegment(0)

	ids := []uint64{1, 2, 3}
	timestamps := []uint64{0, 0, 0}

	_, insertErr := SegmentInsert(segment, &ids, &timestamps, nil)
	assert.NoError(t, insertErr)

	searchRes, searchErr := SegmentSearch(segment, "fake query string", &timestamps, nil)
	assert.NoError(t, searchErr)
	// TODO(review): prefer t.Log over fmt.Println so the output is tied to
	// this test; kept as-is for now because fmt is imported at file level.
	fmt.Println(searchRes)

	// Tear down in reverse order of creation.
	partition.DeleteSegment(segment)
	collection.DeletePartition(partition)
	node.DeleteCollection(collection)
}