mirror of https://github.com/milvus-io/milvus.git
Fix bug and delete unused code
Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/4973/head^2
parent 4ed11d9775
commit f12366342f
go.mod

@@ -7,6 +7,7 @@ require (
     github.com/apache/pulsar-client-go v0.1.1
     github.com/aws/aws-sdk-go v1.30.8
     github.com/coreos/etcd v3.3.25+incompatible // indirect
+    github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548
     github.com/frankban/quicktest v1.10.2 // indirect
     github.com/fsnotify/fsnotify v1.4.9 // indirect
     github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
go.sum

@@ -65,6 +65,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbp
 github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
 github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso=
 github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -329,6 +330,7 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx
 github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
+github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 h1:/NRJ5vAYoqz+7sG51ubIDHXeWO8DlTSrToPu6q11ziA=
 github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M=
 github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=

@@ -57,7 +57,7 @@ type segRequest struct {
     count     uint32
     colName   string
     partition string
-    segID     UniqueID
+    segInfo   map[UniqueID]uint32
     channelID int32
 }

@@ -1,11 +1,15 @@
 package allocator

 import (
+    "container/list"
     "context"
     "fmt"
     "log"
+    "sort"
     "time"

+    "github.com/cznic/mathutil"
+
     "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"

     "github.com/zilliztech/milvus-distributed/internal/errors"
@@ -18,7 +22,10 @@ const (
 )

 type assignInfo struct {
     internalpb.SegIDAssignment
+    collName       string
+    partitionTag   string
+    channelID      int32
+    segInfo        map[UniqueID]uint32 // segmentID->count map
     expireTime     time.Time
     lastInsertTime time.Time
 }
@@ -32,12 +39,16 @@ func (info *assignInfo) IsActive(now time.Time) bool {
 }

 func (info *assignInfo) IsEnough(count uint32) bool {
-    return info.Count >= count
+    total := uint32(0)
+    for _, count := range info.segInfo {
+        total += count
+    }
+    return total >= count
 }

 type SegIDAssigner struct {
     Allocator
-    assignInfos map[string][]*assignInfo // collectionName -> [] *assignInfo
+    assignInfos map[string]*list.List // collectionName -> *list.List
     segReqs     []*internalpb.SegIDRequest
     canDoReqs   []request
 }
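
Since assignInfos now keeps a container/list per collection instead of a slice, every traversal in this file goes through list.Element values plus a type assertion. A minimal standalone sketch of that pattern, with a made-up info type rather than the allocator's own code:

package main

import (
    "container/list"
    "fmt"
)

// stand-in for *assignInfo; only the field printed below
type info struct{ channelID int32 }

func main() {
    infos := list.New()
    infos.PushBack(&info{channelID: 1})
    infos.PushBack(&info{channelID: 2})

    // list.Element.Value is an interface{}, so each visit needs a type assertion.
    for e := infos.Front(); e != nil; e = e.Next() {
        v := e.Value.(*info)
        fmt.Println("channel:", v.channelID)
    }
}
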
@@ -50,11 +61,8 @@ func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, e
         cancel:        cancel,
         masterAddress: masterAddr,
         countPerRPC:   SegCountPerRPC,
-        //toDoReqs:    []request,
     },
-    assignInfos: make(map[string][]*assignInfo),
-    //segReqs:    make([]*internalpb.SegIDRequest, maxConcurrentRequests),
-    //canDoReqs:  make([]request, maxConcurrentRequests),
+    assignInfos: make(map[string]*list.List),
 }
 sa.tChan = &ticker{
     updateInterval: time.Second,
@@ -67,16 +75,17 @@ func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, e

 func (sa *SegIDAssigner) collectExpired() {
     now := time.Now()
-    for _, colInfos := range sa.assignInfos {
-        for _, assign := range colInfos {
+    for _, info := range sa.assignInfos {
+        for e := info.Front(); e != nil; e = e.Next() {
+            assign := e.Value.(*assignInfo)
             if !assign.IsActive(now) || !assign.IsExpired(now) {
                 continue
             }
             sa.segReqs = append(sa.segReqs, &internalpb.SegIDRequest{
-                ChannelID:    assign.ChannelID,
+                ChannelID:    assign.channelID,
                 Count:        sa.countPerRPC,
-                CollName:     assign.CollName,
-                PartitionTag: assign.PartitionTag,
+                CollName:     assign.collName,
+                PartitionTag: assign.partitionTag,
             })
         }
     }
@@ -88,7 +97,6 @@ func (sa *SegIDAssigner) checkToDoReqs() {
     }
     now := time.Now()
     for _, req := range sa.toDoReqs {
-        fmt.Println("DDDDD????", req)
         segRequest := req.(*segRequest)
         assign := sa.getAssign(segRequest.colName, segRequest.partition, segRequest.channelID)
         if assign == nil || assign.IsExpired(now) || !assign.IsEnough(segRequest.count) {
@@ -102,13 +110,36 @@ func (sa *SegIDAssigner) checkToDoReqs() {
     }
 }

+func (sa *SegIDAssigner) removeSegInfo(colName, partition string, channelID int32) {
+    assignInfos, ok := sa.assignInfos[colName]
+    if !ok {
+        return
+    }
+
+    cnt := assignInfos.Len()
+    if cnt == 0 {
+        return
+    }
+
+    for e := assignInfos.Front(); e != nil; e = e.Next() {
+        assign := e.Value.(*assignInfo)
+        if assign.partitionTag != partition || assign.channelID != channelID {
+            continue
+        }
+        assignInfos.Remove(e)
+    }
+
+}
+
 func (sa *SegIDAssigner) getAssign(colName, partition string, channelID int32) *assignInfo {
-    colInfos, ok := sa.assignInfos[colName]
+    assignInfos, ok := sa.assignInfos[colName]
     if !ok {
         return nil
     }
-    for _, info := range colInfos {
-        if info.PartitionTag != partition || info.ChannelID != channelID {
+
+    for e := assignInfos.Front(); e != nil; e = e.Next() {
+        info := e.Value.(*assignInfo)
+        if info.partitionTag != partition || info.channelID != channelID {
             continue
         }
         return info
@@ -151,19 +182,26 @@ func (sa *SegIDAssigner) syncSegments() {

     now := time.Now()
     expiredTime := now.Add(time.Millisecond * time.Duration(resp.ExpireDuration))
+    for _, info := range resp.PerChannelAssignment {
+        sa.removeSegInfo(info.CollName, info.PartitionTag, info.ChannelID)
+    }
+
     for _, info := range resp.PerChannelAssignment {
         assign := sa.getAssign(info.CollName, info.PartitionTag, info.ChannelID)
         if assign == nil {
             colInfos := sa.assignInfos[info.CollName]
+            segInfo := make(map[UniqueID]uint32)
+            segInfo[info.SegID] = info.Count
             newAssign := &assignInfo{
                 SegIDAssignment: *info,
                 expireTime:      expiredTime,
                 lastInsertTime:  now,
+                collName:        info.CollName,
+                partitionTag:    info.PartitionTag,
+                channelID:       info.ChannelID,
+                segInfo:         segInfo,
             }
-            colInfos = append(colInfos, newAssign)
+            colInfos.PushBack(newAssign)
             sa.assignInfos[info.CollName] = colInfos
         } else {
             assign.SegIDAssignment = *info
+            assign.segInfo[info.SegID] = info.Count
             assign.expireTime = expiredTime
             assign.lastInsertTime = now
         }
@@ -181,13 +219,38 @@ func (sa *SegIDAssigner) processFunc(req request) error {
     if assign == nil {
         return errors.New("Failed to GetSegmentID")
     }
-    segRequest.segID = assign.SegID
-    assign.Count -= segRequest.count
+
+    keys := make([]UniqueID, len(assign.segInfo))
+    i := 0
+    for key := range assign.segInfo {
+        keys[i] = key
+        i++
+    }
+    reqCount := segRequest.count
+
+    resultSegInfo := make(map[UniqueID]uint32)
+    sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
+    for _, key := range keys {
+        if reqCount <= 0 {
+            break
+        }
+        cur := assign.segInfo[key]
+        minCnt := mathutil.MinUint32(cur, reqCount)
+        resultSegInfo[key] = minCnt
+        cur -= minCnt
+        reqCount -= minCnt
+        if cur <= 0 {
+            delete(assign.segInfo, key)
+        } else {
+            assign.segInfo[key] = cur
+        }
+    }
+    segRequest.segInfo = resultSegInfo
     fmt.Println("process segmentID")
     return nil
 }

-func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (UniqueID, error) {
+func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (map[UniqueID]uint32, error) {
     req := &segRequest{
         baseRequest: baseRequest{done: make(chan error), valid: false},
         colName:     colName,
@@ -199,7 +262,7 @@ func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32
     req.Wait()

     if !req.IsValid() {
-        return 0, errors.New("GetSegmentID Failed")
+        return nil, errors.New("GetSegmentID Failed")
     }
-    return req.segID, nil
+    return req.segInfo, nil
 }
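
GetSegmentID now hands back a map from segment ID to row count instead of a single UniqueID. A rough caller-side sketch under that assumption; the package name, collection and partition strings, and the row count below are placeholders, and the snippet assumes it is compiled inside this repository so the internal allocator package is importable:

package example

import (
    "log"

    "github.com/zilliztech/milvus-distributed/internal/allocator"
)

// LogAssignment asks the assigner for 100 rows on channel 0 and logs how they
// are spread across the returned segments.
func LogAssignment(segIDAssigner *allocator.SegIDAssigner) error {
    segCounts, err := segIDAssigner.GetSegmentID("collection0", "default", 0, 100)
    if err != nil {
        return err
    }
    total := uint32(0)
    for segID, n := range segCounts {
        total += n
        log.Printf("segment %d takes %d rows", segID, n)
    }
    log.Printf("assigned %d of the 100 requested rows", total)
    return nil
}
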

@@ -13,7 +13,7 @@ import (
 type Timestamp = typeutil.Timestamp

 const (
-    tsCountPerRPC = 2 << 18 * 10
+    tsCountPerRPC = 2 << 15
 )

 type TimestampAllocator struct {
@@ -37,6 +37,7 @@ func NewTimestampAllocator(ctx context.Context, masterAddr string) (*TimestampAl
     }
     a.Allocator.syncFunc = a.syncTs
     a.Allocator.processFunc = a.processFunc
+    a.Allocator.checkFunc = a.checkFunc
     return a, nil
 }
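
For scale: << and * share precedence in Go and group left to right, so the old batch size 2 << 18 * 10 equals (2<<18)*10 = 5,242,880 timestamps reserved per allocation RPC, while the new 2 << 15 is 65,536, roughly an 80x smaller window per sync.
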
@@ -18,7 +18,6 @@
 #include <knowhere/index/vector_index/adapter/VectorAdapter.h>
 #include <knowhere/index/vector_index/VecIndexFactory.h>
 #include <cstdint>
-#include <boost/concept_check.hpp>

 CSegmentBase
 NewSegment(CCollection collection, uint64_t segment_id) {
@@ -42,7 +41,7 @@ DeleteSegment(CSegmentBase segment) {

 //////////////////////////////////////////////////////////////////

-CStatus
+int
 Insert(CSegmentBase c_segment,
        int64_t reserved_offset,
        int64_t size,
@@ -58,22 +57,11 @@ Insert(CSegmentBase c_segment,
     dataChunk.sizeof_per_row = sizeof_per_row;
     dataChunk.count = count;

-    try {
-        auto res = segment->Insert(reserved_offset, size, row_ids, timestamps, dataChunk);
-
-        auto status = CStatus();
-        status.error_code = Success;
-        status.error_msg = "";
-        return status;
-    } catch (std::runtime_error& e) {
-        auto status = CStatus();
-        status.error_code = UnexpectedException;
-        status.error_msg = strdup(e.what());
-        return status;
-    }
+    auto res = segment->Insert(reserved_offset, size, row_ids, timestamps, dataChunk);

     // TODO: delete print
     // std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
+    return res.code();
 }

 int64_t
@@ -85,24 +73,13 @@ PreInsert(CSegmentBase c_segment, int64_t size) {
     return segment->PreInsert(size);
 }

-CStatus
+int
 Delete(
     CSegmentBase c_segment, int64_t reserved_offset, int64_t size, const int64_t* row_ids, const uint64_t* timestamps) {
     auto segment = (milvus::segcore::SegmentBase*)c_segment;

-    try {
-        auto res = segment->Delete(reserved_offset, size, row_ids, timestamps);
-
-        auto status = CStatus();
-        status.error_code = Success;
-        status.error_msg = "";
-        return status;
-    } catch (std::runtime_error& e) {
-        auto status = CStatus();
-        status.error_code = UnexpectedException;
-        status.error_msg = strdup(e.what());
-        return status;
-    }
+    auto res = segment->Delete(reserved_offset, size, row_ids, timestamps);
+    return res.code();
 }

 int64_t
@@ -114,7 +91,7 @@ PreDelete(CSegmentBase c_segment, int64_t size) {
     return segment->PreDelete(size);
 }

-CStatus
+int
 Search(CSegmentBase c_segment,
        CPlan c_plan,
        CPlaceholderGroup* c_placeholder_groups,
@@ -130,22 +107,14 @@ Search(CSegmentBase c_segment,
     }
     milvus::segcore::QueryResult query_result;

-    auto status = CStatus();
-    try {
-        auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result);
-        status.error_code = Success;
-        status.error_msg = "";
-    } catch (std::runtime_error& e) {
-        status.error_code = UnexpectedException;
-        status.error_msg = strdup(e.what());
-    }
+    auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result);

     // result_ids and result_distances have been allocated memory in goLang,
     // so we don't need to malloc here.
     memcpy(result_ids, query_result.result_ids_.data(), query_result.get_row_count() * sizeof(int64_t));
     memcpy(result_distances, query_result.result_distances_.data(), query_result.get_row_count() * sizeof(float));

-    return status;
+    return res.code();
 }

 //////////////////////////////////////////////////////////////////

@@ -14,24 +14,12 @@ extern "C" {
 #endif

 #include <stdbool.h>
-#include <stdlib.h>
-#include <stdint.h>

 #include "segcore/collection_c.h"
 #include "segcore/plan_c.h"
+#include <stdint.h>

 typedef void* CSegmentBase;

-enum ErrorCode {
-    Success = 0,
-    UnexpectedException = 1,
-};
-
-typedef struct CStatus {
-    int error_code;
-    const char* error_msg;
-} CStatus;
-
 CSegmentBase
 NewSegment(CCollection collection, uint64_t segment_id);

@@ -40,7 +28,7 @@ DeleteSegment(CSegmentBase segment);

 //////////////////////////////////////////////////////////////////

-CStatus
+int
 Insert(CSegmentBase c_segment,
        int64_t reserved_offset,
        int64_t size,
@@ -53,14 +41,14 @@ Insert(CSegmentBase c_segment,
 int64_t
 PreInsert(CSegmentBase c_segment, int64_t size);

-CStatus
+int
 Delete(
     CSegmentBase c_segment, int64_t reserved_offset, int64_t size, const int64_t* row_ids, const uint64_t* timestamps);

 int64_t
 PreDelete(CSegmentBase c_segment, int64_t size);

-CStatus
+int
 Search(CSegmentBase c_segment,
        CPlan plan,
        CPlaceholderGroup* placeholder_groups,
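
With CStatus gone from this header, Insert, Delete and Search now report errors as a plain int (0 meaning success, anything else the wrapped C++ status code), so the cgo callers later in this diff drop the error_code/error_msg handling and the C.free of the message in favor of a simple status != 0 check.
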

@@ -65,7 +65,7 @@ TEST(CApiTest, InsertTest) {

     auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);

-    assert(res.error_code == Success);
+    assert(res == 0);

     DeleteCollection(collection);
     DeleteSegment(segment);
@@ -82,7 +82,7 @@ TEST(CApiTest, DeleteTest) {
     auto offset = PreDelete(segment, 3);

     auto del_res = Delete(segment, offset, 3, delete_row_ids, delete_timestamps);
-    assert(del_res.error_code == Success);
+    assert(del_res == 0);

     DeleteCollection(collection);
     DeleteSegment(segment);
@@ -116,7 +116,7 @@ TEST(CApiTest, SearchTest) {
     auto offset = PreInsert(segment, N);

     auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
-    assert(ins_res.error_code == Success);
+    assert(ins_res == 0);

     const char* dsl_string = R"(
     {
@@ -163,7 +163,7 @@ TEST(CApiTest, SearchTest) {
     float result_distances[100];

     auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
-    assert(sea_res.error_code == Success);
+    assert(sea_res == 0);

     DeletePlan(plan);
     DeletePlaceholderGroup(placeholderGroup);
@@ -199,7 +199,7 @@ TEST(CApiTest, BuildIndexTest) {
     auto offset = PreInsert(segment, N);

     auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
-    assert(ins_res.error_code == Success);
+    assert(ins_res == 0);

     // TODO: add index ptr
     Close(segment);
@@ -250,7 +250,7 @@ TEST(CApiTest, BuildIndexTest) {
     float result_distances[100];

     auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
-    assert(sea_res.error_code == Success);
+    assert(sea_res == 0);

     DeletePlan(plan);
     DeletePlaceholderGroup(placeholderGroup);
@@ -315,7 +315,7 @@ TEST(CApiTest, GetMemoryUsageInBytesTest) {

     auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);

-    assert(res.error_code == Success);
+    assert(res == 0);

     auto memory_usage_size = GetMemoryUsageInBytes(segment);

@@ -482,7 +482,7 @@ TEST(CApiTest, GetDeletedCountTest) {
     auto offset = PreDelete(segment, 3);

     auto del_res = Delete(segment, offset, 3, delete_row_ids, delete_timestamps);
-    assert(del_res.error_code == Success);
+    assert(del_res == 0);

     // TODO: assert(deleted_count == len(delete_row_ids))
     auto deleted_count = GetDeletedCount(segment);
@@ -502,7 +502,7 @@ TEST(CApiTest, GetRowCountTest) {
     auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
     auto offset = PreInsert(segment, N);
     auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
-    assert(res.error_code == Success);
+    assert(res == 0);

     auto row_count = GetRowCount(segment);
     assert(row_count == N);

@@ -96,6 +96,27 @@ func (pt *ParamTable) ProxyIDList() []UniqueID {
     return ret
 }

+func (pt *ParamTable) queryNodeNum() int {
+    return len(pt.queryNodeIDList())
+}
+
+func (pt *ParamTable) queryNodeIDList() []UniqueID {
+    queryNodeIDStr, err := pt.Load("nodeID.queryNodeIDList")
+    if err != nil {
+        panic(err)
+    }
+    var ret []UniqueID
+    queryNodeIDs := strings.Split(queryNodeIDStr, ",")
+    for _, i := range queryNodeIDs {
+        v, err := strconv.Atoi(i)
+        if err != nil {
+            log.Panicf("load proxy id list error, %s", err.Error())
+        }
+        ret = append(ret, UniqueID(v))
+    }
+    return ret
+}
+
 func (pt *ParamTable) ProxyID() UniqueID {
     proxyID, err := pt.Load("_proxyID")
     if err != nil {
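
The new queryNodeIDList helper treats the nodeID.queryNodeIDList config entry as a comma-separated list, so a value such as "1,2" becomes the UniqueID slice [1 2] and queryNodeNum() reports 2; the task scheduler below uses that count in place of the previously hard-coded 4 expected search results.
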
@@ -396,11 +417,11 @@ func (pt *ParamTable) searchChannelNames() []string {
 }

 func (pt *ParamTable) searchResultChannelNames() []string {
-    ch, err := pt.Load("msgChannel.chanNamePrefix.search")
+    ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
     if err != nil {
         log.Fatal(err)
     }
-    channelRange, err := pt.Load("msgChannel.channelRange.search")
+    channelRange, err := pt.Load("msgChannel.channelRange.searchResult")
     if err != nil {
         panic(err)
     }

@@ -55,12 +55,11 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
         proxyLoopCancel: cancel,
     }

-    // TODO: use config instead
     pulsarAddress := Params.PulsarAddress()

     p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
     p.queryMsgStream.SetPulsarClient(pulsarAddress)
-    p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
+    p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames())

     masterAddr := Params.MasterAddress()
     idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr)
@@ -84,7 +83,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {

     p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize())
     p.manipulationMsgStream.SetPulsarClient(pulsarAddress)
-    p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
+    p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames())
     repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
         return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false)
     }

@@ -229,7 +229,7 @@ func TestProxy_Insert(t *testing.T) {
         collectionName := "CreateCollection" + strconv.FormatInt(int64(i), 10)
         req := &servicepb.RowBatch{
             CollectionName: collectionName,
-            PartitionTag:   "",
+            PartitionTag:   "haha",
             RowData:        make([]*commonpb.Blob, 0),
             HashKeys:       make([]int32, 0),
         }
@@ -237,6 +237,7 @@ func TestProxy_Insert(t *testing.T) {
         wg.Add(1)
         go func(group *sync.WaitGroup) {
             defer group.Done()
+            createCollection(t, collectionName)
             has := hasCollection(t, collectionName)
             if has {
                 resp, err := proxyClient.Insert(ctx, req)

@@ -1,6 +1,9 @@
 package proxy

 import (
+    "log"
+    "sort"
+
     "github.com/zilliztech/milvus-distributed/internal/allocator"
     "github.com/zilliztech/milvus-distributed/internal/errors"
     "github.com/zilliztech/milvus-distributed/internal/msgstream"
@@ -15,6 +18,9 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,

     result := make(map[int32]*msgstream.MsgPack)

+    channelCountMap := make(map[UniqueID]map[int32]uint32) // reqID --> channelID to count
+    reqSchemaMap := make(map[UniqueID][]string)
+
     for i, request := range tsMsgs {
         if request.Type() != internalpb.MsgType_kInsert {
             return nil, errors.New(string("msg's must be Insert"))
@@ -23,8 +29,8 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
         if !ok {
             return nil, errors.New(string("msg's must be Insert"))
         }
-        keys := hashKeys[i]

+        keys := hashKeys[i]
         timestampLen := len(insertRequest.Timestamps)
         rowIDLen := len(insertRequest.RowIDs)
         rowDataLen := len(insertRequest.RowData)
@@ -34,10 +40,84 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
             return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal"))
         }

+        reqID := insertRequest.ReqID
+        if _, ok := channelCountMap[reqID]; !ok {
+            channelCountMap[reqID] = make(map[int32]uint32)
+        }
+
+        if _, ok := reqSchemaMap[reqID]; !ok {
+            reqSchemaMap[reqID] = []string{insertRequest.CollectionName, insertRequest.PartitionTag}
+        }
+
+        for _, channelID := range keys {
+            channelCountMap[reqID][channelID]++
+        }
+
+    }
+
+    reqSegCountMap := make(map[UniqueID]map[int32]map[UniqueID]uint32)
+
+    for reqID, countInfo := range channelCountMap {
+        schema := reqSchemaMap[reqID]
+        collName, partitionTag := schema[0], schema[1]
+        for channelID, count := range countInfo {
+            mapInfo, err := segIDAssigner.GetSegmentID(collName, partitionTag, channelID, count)
+            if err != nil {
+                return nil, err
+            }
+            reqSegCountMap[reqID][channelID] = mapInfo
+        }
+    }
+
+    reqSegAccumulateCountMap := make(map[UniqueID]map[int32][]uint32)
+    reqSegIDMap := make(map[UniqueID]map[int32][]UniqueID)
+    reqSegAllocateCounter := make(map[UniqueID]map[int32]uint32)
+
+    for reqID, channelInfo := range reqSegCountMap {
+        for channelID, segInfo := range channelInfo {
+            reqSegAllocateCounter[reqID][channelID] = 0
+            keys := make([]UniqueID, len(segInfo))
+            i := 0
+            for key := range segInfo {
+                keys[i] = key
+                i++
+            }
+            sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
+            accumulate := uint32(0)
+            for _, key := range keys {
+                accumulate += segInfo[key]
+                reqSegAccumulateCountMap[reqID][channelID] = append(
+                    reqSegAccumulateCountMap[reqID][channelID],
+                    accumulate,
+                )
+                reqSegIDMap[reqID][channelID] = append(
+                    reqSegIDMap[reqID][channelID],
+                    key,
+                )
+            }
+        }
+    }
+
+    var getSegmentID = func(reqID UniqueID, channelID int32) UniqueID {
+        reqSegAllocateCounter[reqID][channelID]++
+        cur := reqSegAllocateCounter[reqID][channelID]
+        accumulateSlice := reqSegAccumulateCountMap[reqID][channelID]
+        segIDSlice := reqSegIDMap[reqID][channelID]
+        for index, count := range accumulateSlice {
+            if cur <= count {
+                return segIDSlice[index]
+            }
+        }
+        log.Panic("Can't Found SegmentID")
+        return 0
+    }
+
+    for i, request := range tsMsgs {
+        insertRequest := request.(*msgstream.InsertMsg)
+        keys := hashKeys[i]
+        reqID := insertRequest.ReqID
         collectionName := insertRequest.CollectionName
         partitionTag := insertRequest.PartitionTag
         channelID := insertRequest.ChannelID
         proxyID := insertRequest.ProxyID
         for index, key := range keys {
             ts := insertRequest.Timestamps[index]
@@ -48,13 +128,14 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
                 msgPack := msgstream.MsgPack{}
                 result[key] = &msgPack
             }
+            segmentID := getSegmentID(reqID, key)
             sliceRequest := internalpb.InsertRequest{
                 MsgType:        internalpb.MsgType_kInsert,
                 ReqID:          reqID,
                 CollectionName: collectionName,
                 PartitionTag:   partitionTag,
-                SegmentID:      0, // will be assigned later if together
-                ChannelID:      channelID,
+                SegmentID:      segmentID,
+                ChannelID:      int64(key),
                 ProxyID:        proxyID,
                 Timestamps:     []uint64{ts},
                 RowIDs:         []int64{rowID},
@@ -73,25 +154,10 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
                 accMsgs.RowData = append(accMsgs.RowData, row)
             }
         } else { // every row is a message
-            segID, _ := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), 1)
-            insertMsg.SegmentID = segID
             result[key].Msgs = append(result[key].Msgs, insertMsg)
         }
       }
     }

-    if together {
-        for key := range result {
-            insertMsg, _ := result[key].Msgs[0].(*msgstream.InsertMsg)
-            rowNums := len(insertMsg.RowIDs)
-            collectionName := insertMsg.CollectionName
-            partitionTag := insertMsg.PartitionTag
-            channelID := insertMsg.ChannelID
-            segID, _ := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), uint32(rowNums))
-            insertMsg.SegmentID = segID
-            result[key].Msgs[0] = insertMsg
-        }
-    }
-
     return result, nil
 }
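
The bookkeeping added to insertRepackFunc (per-request channel counts, one GetSegmentID call per channel, an accumulate slice, and the getSegmentID closure) amounts to a prefix-sum lookup per (request, channel) pair. A self-contained toy with made-up numbers, not the proxy's code, showing how successive rows map onto the granted segments:

package main

import (
    "fmt"
    "sort"
)

func main() {
    // Suppose the assigner granted {segment 7: 2 rows, segment 9: 3 rows}
    // for one channel of one request; rows are handed out in segment-ID order.
    segInfo := map[int64]uint32{7: 2, 9: 3}

    segIDs := make([]int64, 0, len(segInfo))
    for id := range segInfo {
        segIDs = append(segIDs, id)
    }
    sort.Slice(segIDs, func(i, j int) bool { return segIDs[i] < segIDs[j] })

    accumulate := make([]uint32, 0, len(segIDs)) // running totals: [2, 5]
    total := uint32(0)
    for _, id := range segIDs {
        total += segInfo[id]
        accumulate = append(accumulate, total)
    }

    counter := uint32(0)
    getSegmentID := func() int64 {
        counter++ // the counter-th row of this (request, channel) pair
        for i, cnt := range accumulate {
            if counter <= cnt {
                return segIDs[i]
            }
        }
        panic("more rows requested than were assigned")
    }

    for row := 1; row <= 5; row++ {
        fmt.Printf("row %d -> segment %d\n", row, getSegmentID())
    }
    // rows 1-2 land in segment 7, rows 3-5 in segment 9
}

One caveat about the hunk as shown: writes such as reqSegCountMap[reqID][channelID] = mapInfo assume the inner maps already exist; if nothing outside this diff initializes them, that assignment would panic with "assignment to entry in nil map".
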

@@ -34,7 +34,6 @@ type BaseInsertTask = msgstream.InsertMsg
 type InsertTask struct {
     BaseInsertTask
     Condition
-    ts                    Timestamp
     result                *servicepb.IntegerRangeResponse
     manipulationMsgStream *msgstream.PulsarMsgStream
     ctx                   context.Context
@@ -46,15 +45,21 @@ func (it *InsertTask) SetID(uid UniqueID) {
 }

 func (it *InsertTask) SetTs(ts Timestamp) {
-    it.ts = ts
+    rowNum := len(it.RowData)
+    it.Timestamps = make([]uint64, rowNum)
+    for index := range it.Timestamps {
+        it.Timestamps[index] = ts
+    }
+    it.BeginTimestamp = ts
+    it.EndTimestamp = ts
 }

 func (it *InsertTask) BeginTs() Timestamp {
-    return it.ts
+    return it.BeginTimestamp
 }

 func (it *InsertTask) EndTs() Timestamp {
-    return it.ts
+    return it.EndTimestamp
 }

 func (it *InsertTask) ID() UniqueID {

@@ -186,16 +186,7 @@ type DqTaskQueue struct {
 func (queue *DdTaskQueue) Enqueue(t task) error {
     queue.lock.Lock()
     defer queue.lock.Unlock()
-
-    ts, _ := queue.sched.tsoAllocator.AllocOne()
-    log.Printf("[Proxy] allocate timestamp: %v", ts)
-    t.SetTs(ts)
-
-    reqID, _ := queue.sched.idAllocator.AllocOne()
-    log.Printf("[Proxy] allocate reqID: %v", reqID)
-    t.SetID(reqID)
-
-    return queue.addUnissuedTask(t)
+    return queue.BaseTaskQueue.Enqueue(t)
 }

 func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue {
@@ -369,14 +360,14 @@ func (sched *TaskScheduler) queryLoop() {
 func (sched *TaskScheduler) queryResultLoop() {
     defer sched.wg.Done()

-    // TODO: use config instead
     unmarshal := msgstream.NewUnmarshalDispatcher()
     queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
     queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
-    queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
+    queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(),
         Params.ProxySubName(),
         unmarshal,
         Params.MsgStreamSearchResultPulsarBufSize())
+    queryNodeNum := Params.queryNodeNum()

     queryResultMsgStream.Start()
     defer queryResultMsgStream.Close()
@@ -401,8 +392,7 @@ func (sched *TaskScheduler) queryResultLoop() {
                 queryResultBuf[reqID] = make([]*internalpb.SearchResult, 0)
             }
             queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResult)
-            if len(queryResultBuf[reqID]) == 4 {
-                // TODO: use the number of query node instead
+            if len(queryResultBuf[reqID]) == queryNodeNum {
                 t := sched.getTaskByReqID(reqID)
                 if t != nil {
                     qt, ok := t.(*QueryTask)

@@ -106,7 +106,6 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn
     if err != nil {
         log.Println("cannot find segment:", segmentID)
         // TODO: add error handling
         wg.Done()
         return
     }

@@ -117,9 +116,8 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn

     err = targetSegment.segmentInsert(offsets, &ids, &timestamps, &records)
     if err != nil {
-        log.Println(err)
+        log.Println("insert failed")
         // TODO: add error handling
         wg.Done()
         return
     }


@@ -273,11 +273,11 @@ func (p *ParamTable) searchChannelNames() []string {
 }

 func (p *ParamTable) searchResultChannelNames() []string {
-    ch, err := p.Load("msgChannel.chanNamePrefix.search")
+    ch, err := p.Load("msgChannel.chanNamePrefix.searchResult")
     if err != nil {
         log.Fatal(err)
     }
-    channelRange, err := p.Load("msgChannel.channelRange.search")
+    channelRange, err := p.Load("msgChannel.channelRange.searchResult")
     if err != nil {
         panic(err)
     }

@@ -109,7 +109,7 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
 //-------------------------------------------------------------------------------------- dm & search functions
 func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps *[]Timestamp, records *[]*commonpb.Blob) error {
     /*
-    CStatus
+    int
     Insert(CSegmentBase c_segment,
            long int reserved_offset,
            signed long int size,
@@ -148,12 +148,8 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps
         cSizeofPerRow,
         cNumOfRows)

-    errorCode := status.error_code
-
-    if errorCode != 0 {
-        errorMsg := C.GoString(status.error_msg)
-        defer C.free(unsafe.Pointer(status.error_msg))
-        return errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
+    if status != 0 {
+        return errors.New("Insert failed, error code = " + strconv.Itoa(int(status)))
     }

     s.recentlyModified = true
@@ -162,7 +158,7 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps

 func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps *[]Timestamp) error {
     /*
-    CStatus
+    int
     Delete(CSegmentBase c_segment,
            long int reserved_offset,
            long size,
@@ -176,12 +172,8 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps

     var status = C.Delete(s.segmentPtr, cOffset, cSize, cEntityIdsPtr, cTimestampsPtr)

-    errorCode := status.error_code
-
-    if errorCode != 0 {
-        errorMsg := C.GoString(status.error_msg)
-        defer C.free(unsafe.Pointer(status.error_msg))
-        return errors.New("Delete failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
+    if status != 0 {
+        return errors.New("Delete failed, error code = " + strconv.Itoa(int(status)))
     }

     return nil
@@ -195,8 +187,7 @@ func (s *Segment) segmentSearch(plan *Plan,
     numQueries int64,
     topK int64) error {
     /*
-    CStatus
-    Search(void* plan,
+    void* Search(void* plan,
         void* placeholder_groups,
         uint64_t* timestamps,
         int num_groups,
@@ -220,20 +211,16 @@ func (s *Segment) segmentSearch(plan *Plan,
     var cNumGroups = C.int(len(placeHolderGroups))

     var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cNewResultIds, cNewResultDistances)
-    errorCode := status.error_code
-
-    if errorCode != 0 {
-        errorMsg := C.GoString(status.error_msg)
-        defer C.free(unsafe.Pointer(status.error_msg))
-        return errors.New("Search failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
+    if status != 0 {
+        return errors.New("search failed, error code = " + strconv.Itoa(int(status)))
     }

     cNumQueries := C.long(numQueries)
     cTopK := C.long(topK)
     // reduce search result
-    mergeStatus := C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds)
-    if mergeStatus != 0 {
-        return errors.New("merge search result failed, error code = " + strconv.Itoa(int(mergeStatus)))
+    status = C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds)
+    if status != 0 {
+        return errors.New("merge search result failed, error code = " + strconv.Itoa(int(status)))
     }
     return nil
 }

@@ -463,6 +463,7 @@ func TestSegment_segmentInsert(t *testing.T) {

     err := segment.segmentInsert(offset, &ids, &timestamps, &records)
     assert.NoError(t, err)
+
     deleteSegment(segment)
     deleteCollection(collection)
 }