diff --git a/cmd/master/main.go b/cmd/master/main.go index ff3ca927c7..d05fd102e5 100644 --- a/cmd/master/main.go +++ b/cmd/master/main.go @@ -6,7 +6,6 @@ import ( "log" "os" "os/signal" - "strconv" "syscall" "github.com/zilliztech/milvus-distributed/internal/conf" @@ -25,11 +24,7 @@ func main() { // Creates server. ctx, cancel := context.WithCancel(context.Background()) - etcdAddr := conf.Config.Etcd.Address - etcdAddr += ":" - etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10) - - svr, err := master.CreateServer(ctx, conf.Config.Etcd.Rootpath, conf.Config.Etcd.Rootpath, conf.Config.Etcd.Rootpath, []string{etcdAddr}) + svr, err := master.CreateServer(ctx) if err != nil { log.Print("create server failed", zap.Error(err)) } @@ -47,9 +42,7 @@ func main() { cancel() }() - grpcPort := int64(conf.Config.Master.Port) - - if err := svr.Run(grpcPort); err != nil { + if err := svr.Run(); err != nil { log.Fatal("run server failed", zap.Error(err)) } diff --git a/configs/config.yaml b/configs/config.yaml index 7973e34075..14acd58a61 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -21,7 +21,7 @@ master: etcd: address: localhost - port: 2379 + port: 12379 rootpath: by-dev segthreshold: 10000 diff --git a/deployments/docker/docker-compose.yml b/deployments/docker/docker-compose.yml index e801ddfc86..9ad68a5463 100644 --- a/deployments/docker/docker-compose.yml +++ b/deployments/docker/docker-compose.yml @@ -3,11 +3,11 @@ version: '3.5' services: etcd: image: quay.io/coreos/etcd:latest - command: etcd -listen-peer-urls=http://127.0.0.1:2380 -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379,http://0.0.0.0:4001 -initial-advertise-peer-urls=http://127.0.0.1:2380 --initial-cluster default=http://127.0.0.1:2380 + command: etcd -listen-peer-urls=http://127.0.0.1:12380 -advertise-client-urls=http://127.0.0.1:12379 -listen-client-urls http://0.0.0.0:12379,http://0.0.0.0:14001 -initial-advertise-peer-urls=http://127.0.0.1:12380 --initial-cluster default=http://127.0.0.1:12380 ports: - - "2379:2379" - - "2380:2380" - - "4001:4001" + - "12379:12379" + - "12380:12380" + - "14001:14001" pulsar: image: apachepulsar/pulsar:latest @@ -16,44 +16,44 @@ services: - "6650:6650" - "18080:8080" -# pd0: -# image: pingcap/pd:latest -# network_mode: "host" -# ports: -# - "2379:2379" -# - "2380:2380" -# volumes: -# - /tmp/config/pd.toml:/pd.toml:ro -# - /tmp/data:/data -# - /tmp/logs:/logs -# - /etc/localtime:/etc/localtime:ro -# command: -# - --name=pd0 -# - --client-urls=http://0.0.0.0:2379 -# - --peer-urls=http://0.0.0.0:2380 -# - --advertise-client-urls=http://127.0.0.1:2379 -# - --advertise-peer-urls=http://127.0.0.1:2380 -# - --initial-cluster=pd0=http://127.0.0.1:2380 -# - --data-dir=/data/pd0 -# - --log-file=/logs/pd0.log -# restart: on-failure -# -# tikv0: -# network_mode: "host" -# image: pingcap/tikv:latest -# ports: -# - "20160:20160" -# volumes: -# - /tmp/config/tikv.toml:/tikv.toml:ro -# - /tmp/data:/data -# - /tmp/logs:/logs -# - /etc/localtime:/etc/localtime:ro -# command: -# - --addr=0.0.0.0:20160 -# - --advertise-addr=127.0.0.1:20160 -# - --data-dir=/data/tikv0 -# - --pd=127.0.0.1:2379 -# - --log-file=/logs/tikv0.log -# depends_on: -# - "pd0" -# restart: on-failure + pd0: + image: pingcap/pd:latest + network_mode: "host" + ports: + - "2379:2379" + - "2380:2380" + volumes: + - /tmp/config/pd.toml:/pd.toml:ro + - /tmp/data:/data + - /tmp/logs:/logs + - /etc/localtime:/etc/localtime:ro + command: + - --name=pd0 + - --client-urls=http://0.0.0.0:2379 + - 
--peer-urls=http://0.0.0.0:2380 + - --advertise-client-urls=http://127.0.0.1:2379 + - --advertise-peer-urls=http://127.0.0.1:2380 + - --initial-cluster=pd0=http://127.0.0.1:2380 + - --data-dir=/data/pd0 + - --log-file=/logs/pd0.log + restart: on-failure + + tikv0: + network_mode: "host" + image: pingcap/tikv:latest + ports: + - "20160:20160" + volumes: + - /tmp/config/tikv.toml:/tikv.toml:ro + - /tmp/data:/data + - /tmp/logs:/logs + - /etc/localtime:/etc/localtime:ro + command: + - --addr=0.0.0.0:20160 + - --advertise-addr=127.0.0.1:20160 + - --data-dir=/data/tikv0 + - --pd=127.0.0.1:2379 + - --log-file=/logs/tikv0.log + depends_on: + - "pd0" + restart: on-failure diff --git a/docs/developer_guides/developer_guides.md b/docs/developer_guides/developer_guides.md index c2b21bbf07..b28d507181 100644 --- a/docs/developer_guides/developer_guides.md +++ b/docs/developer_guides/developer_guides.md @@ -1169,15 +1169,14 @@ type softTimeTickBarrier struct { minTtInterval Timestamp lastTt Timestamp outTt chan Timestamp - ttStream MsgStream + ttStream *MsgStream ctx context.Context } func (ttBarrier *softTimeTickBarrier) GetTimeTick() (Timestamp,error) func (ttBarrier *softTimeTickBarrier) Start() error -func (ttBarrier *softTimeTickBarrier) Close() -func NewSoftTimeTickBarrier(ctx context.Context, ttStream *MsgStream, peerIds []UniqueID, minTtInterval Timestamp) *softTimeTickBarrier +func newSoftTimeTickBarrier(ctx context.Context, ttStream *MsgStream, peerIds []UniqueId, minTtInterval Timestamp) *softTimeTickBarrier ``` @@ -1190,15 +1189,14 @@ func NewSoftTimeTickBarrier(ctx context.Context, ttStream *MsgStream, peerIds [] type hardTimeTickBarrier struct { peer2Tt map[UniqueId]List outTt chan Timestamp - ttStream MsgStream + ttStream *MsgStream ctx context.Context } func (ttBarrier *hardTimeTickBarrier) GetTimeTick() (Timestamp,error) func (ttBarrier *hardTimeTickBarrier) Start() error -func (ttBarrier *hardTimeTickBarrier) Close() -func NewHardTimeTickBarrier(ctx context.Context, ttStream *MsgStream, peerIds []UniqueID) *hardTimeTickBarrier +func newHardTimeTickBarrier(ctx context.Context, ttStream *MsgStream, peerIds []UniqueId) *softTimeTickBarrier ``` @@ -1212,7 +1210,6 @@ func NewHardTimeTickBarrier(ctx context.Context, ttStream *MsgStream, peerIds [] type TimeTickBarrier interface { GetTimeTick() (Timestamp,error) Start() error - Close() } type timeSyncMsgProducer struct { diff --git a/go.mod b/go.mod index 1eb6a53930..d66b97adf5 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/coreos/etcd v3.3.25+incompatible // indirect github.com/frankban/quicktest v1.10.2 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/gogo/protobuf v1.3.1 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.3.2 github.com/google/btree v1.0.0 diff --git a/go.sum b/go.sum index bc892d5f82..5556e6f74d 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,6 @@ github.com/aws/aws-sdk-go v1.30.8 h1:4BHbh8K3qKmcnAgToZ2LShldRF9inoqIBccpCLNCy3I github.com/aws/aws-sdk-go v1.30.8/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= 
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= @@ -69,8 +68,9 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v0.0.0-20180717141946-636bf0302bc9/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -133,6 +133,7 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.10.8/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.10.11 h1:K9z59aO18Aywg2b/WSgBaUX99mHy2BES18Cr5lBKZHk= @@ -195,7 +196,6 @@ github.com/ozonru/etcd v3.3.20-grpc1.27-origmodule+incompatible/go.mod h1:iIubIL github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4 v2.5.2+incompatible h1:WCjObylUIOlKy/+7Abdn34TLIkXiA4UWUMhxq9m9ZXI= github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 h1:USx2/E1bX46VG32FIw034Au6seQ2fY9NEILmNh/UlQg= github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712 h1:R8gStypOBmpnHEx1qi//SaqxJVI4inOqljg/Aj5/390= github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712/go.mod h1:PYMCGwN0JHjoqGr3HrZoD+b8Tgx8bKnArhSq8YVzUMc= @@ -223,25 +223,21 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= -github.com/prometheus/client_golang v1.0.0 h1:vrDKnkGzuGvhNAL56c7DBz29ZL+KxnoR0x7enabFceM= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.5.1 h1:bdHYieyGlH+6OLEk2YQha8THib30KP0/yD0YH9m6xcA= github.com/prometheus/client_golang v1.5.1/go.mod 
h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.4.1 h1:K0MGApIoQvMw27RTdJkPbr3JZ7DNbtxQNyi5STVM6Kw= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.2 h1:6LJUbpNm42llc4HRCuvApCSWB/WfhuNo9K98Q9sNGfs= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= @@ -391,6 +387,7 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqG golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= diff --git a/internal/master/README.md b/internal/master/README.md index 343cef711d..90493cef77 100644 --- a/internal/master/README.md +++ b/internal/master/README.md @@ -22,4 +22,3 @@ go run cmd/master.go ### example if master create a collection with uuid ```46e468ee-b34a-419d-85ed-80c56bfa4e90``` the corresponding key in etcd is $(ETCD_ROOT_PATH)/collection/46e468ee-b34a-419d-85ed-80c56bfa4e90 - diff --git a/internal/master/collection/collection.go b/internal/master/collection/collection.go index 281c32299a..097b4ae8b9 100644 --- a/internal/master/collection/collection.go +++ b/internal/master/collection/collection.go @@ -5,7 +5,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/util/typeutil" - "github.com/golang/protobuf/proto" + "github.com/gogo/protobuf/proto" jsoniter 
"github.com/json-iterator/go" "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" diff --git a/internal/master/collection_task.go b/internal/master/collection_task.go index afa9f260fe..e6207e41a2 100644 --- a/internal/master/collection_task.go +++ b/internal/master/collection_task.go @@ -1,11 +1,11 @@ package master import ( + "encoding/json" "errors" "log" "strconv" - "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -60,41 +60,53 @@ func (t *createCollectionTask) Ts() (Timestamp, error) { if t.req == nil { return 0, errors.New("null request") } - return t.req.Timestamp, nil + return Timestamp(t.req.Timestamp), nil } func (t *createCollectionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } var schema schemapb.CollectionSchema - err := proto.UnmarshalMerge(t.req.Schema.Value, &schema) + err := json.Unmarshal(t.req.Schema.Value, &schema) if err != nil { - return err + _ = t.Notify() + return errors.New("unmarshal CollectionSchema failed") } - collectionId, err := allocGlobalId() - if err != nil { - return err - } - - ts, err := t.Ts() - if err != nil { - return err - } + // TODO: allocate collection id + var collectionId UniqueID = 0 + // TODO: allocate timestamp + var collectionCreateTime Timestamp = 0 collection := etcdpb.CollectionMeta{ Id: collectionId, Schema: &schema, - CreateTime: ts, + CreateTime: collectionCreateTime, // TODO: initial segment? SegmentIds: make([]UniqueID, 0), // TODO: initial partition? PartitionTags: make([]string, 0), } - return t.mt.AddCollection(&collection) + collectionJson, err := json.Marshal(&collection) + if err != nil { + _ = t.Notify() + return errors.New("marshal collection failed") + } + + err = (*t.kvBase).Save(collectionMetaPrefix+strconv.FormatInt(collectionId, 10), string(collectionJson)) + if err != nil { + _ = t.Notify() + return errors.New("save collection failed") + } + + t.mt.collId2Meta[collectionId] = collection + + _ = t.Notify() + return nil } ////////////////////////////////////////////////////////////////////////// @@ -115,12 +127,14 @@ func (t *dropCollectionTask) Ts() (Timestamp, error) { func (t *dropCollectionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } collectionName := t.req.CollectionName.CollectionName collectionMeta, err := t.mt.GetCollectionByName(collectionName) if err != nil { + _ = t.Notify() return err } @@ -128,11 +142,13 @@ func (t *dropCollectionTask) Execute() error { err = (*t.kvBase).Remove(collectionMetaPrefix + strconv.FormatInt(collectionId, 10)) if err != nil { - return err + _ = t.Notify() + return errors.New("save collection failed") } delete(t.mt.collId2Meta, collectionId) + _ = t.Notify() return nil } @@ -154,6 +170,7 @@ func (t *hasCollectionTask) Ts() (Timestamp, error) { func (t *hasCollectionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -163,6 +180,7 @@ func (t *hasCollectionTask) Execute() error { t.hasCollection = true } + _ = t.Notify() return nil } @@ -184,12 +202,14 @@ func (t *describeCollectionTask) Ts() (Timestamp, error) { func (t *describeCollectionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } collectionName := t.req.CollectionName collection, err := t.mt.GetCollectionByName(collectionName.CollectionName) 
if err != nil { + _ = t.Notify() return err } @@ -202,6 +222,7 @@ func (t *describeCollectionTask) Execute() error { t.description = &description + _ = t.Notify() return nil } @@ -223,6 +244,7 @@ func (t *showCollectionsTask) Ts() (Timestamp, error) { func (t *showCollectionsTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -240,5 +262,6 @@ func (t *showCollectionsTask) Execute() error { t.stringListResponse = &stringListResponse + _ = t.Notify() return nil } diff --git a/internal/master/grpc_service.go b/internal/master/grpc_service.go index 78e9a6bc81..caf16239e9 100644 --- a/internal/master/grpc_service.go +++ b/internal/master/grpc_service.go @@ -2,7 +2,6 @@ package master import ( "context" - "github.com/zilliztech/milvus-distributed/internal/master/tso" "time" "github.com/zilliztech/milvus-distributed/internal/errors" @@ -19,13 +18,14 @@ func (s *Master) CreateCollection(ctx context.Context, in *internalpb.CreateColl req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { + err := errors.New("Enqueue failed") return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: "Enqueue failed", @@ -34,9 +34,10 @@ func (s *Master) CreateCollection(ctx context.Context, in *internalpb.CreateColl err = t.WaitToFinish(ctx) if err != nil { + err := errors.New("WaitToFinish failed") return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "create collection failed", + Reason: "WaitToFinish failed", }, err } @@ -50,12 +51,12 @@ func (s *Master) DropCollection(ctx context.Context, in *internalpb.DropCollecti req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return &commonpb.Status{ @@ -83,13 +84,13 @@ func (s *Master) HasCollection(ctx context.Context, in *internalpb.HasCollection req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, hasCollection: false, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return &servicepb.BoolResponse{ @@ -126,13 +127,13 @@ func (s *Master) DescribeCollection(ctx context.Context, in *internalpb.Describe req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, description: nil, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return t.(*describeCollectionTask).description, err @@ -152,13 +153,13 @@ func (s *Master) ShowCollections(ctx context.Context, in *internalpb.ShowCollect req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, stringListResponse: nil, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return t.(*showCollectionsTask).stringListResponse, err @@ -179,12 +180,12 @@ func (s *Master) CreatePartition(ctx context.Context, in *internalpb.CreateParti req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, } - var err = s.scheduler.Enqueue(t) + 
var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return &commonpb.Status{ @@ -212,12 +213,12 @@ func (s *Master) DropPartition(ctx context.Context, in *internalpb.DropPartition req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return &commonpb.Status{ @@ -245,13 +246,13 @@ func (s *Master) HasPartition(ctx context.Context, in *internalpb.HasPartitionRe req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, hasPartition: false, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return &servicepb.BoolResponse{ @@ -288,13 +289,13 @@ func (s *Master) DescribePartition(ctx context.Context, in *internalpb.DescribeP req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, description: nil, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return t.(*describePartitionTask).description, err @@ -314,13 +315,13 @@ func (s *Master) ShowPartitions(ctx context.Context, in *internalpb.ShowPartitio req: in, baseTask: baseTask{ kvBase: s.kvBase, - mt: s.mt, - cv: make(chan error), + mt: &s.mt, + cv: make(chan int), }, stringListResponse: nil, } - var err = s.scheduler.Enqueue(t) + var err = s.scheduler.Enqueue(&t) if err != nil { err := errors.New("Enqueue failed") return t.(*showPartitionTask).stringListResponse, err @@ -339,7 +340,7 @@ func (s *Master) ShowPartitions(ctx context.Context, in *internalpb.ShowPartitio func (s *Master) AllocTimestamp(ctx context.Context, request *internalpb.TsoRequest) (*internalpb.TsoResponse, error) { count := request.GetCount() - ts, err := tso.Alloc(count) + ts, err := s.tsoAllocator.GenerateTSO(count) if err != nil { return &internalpb.TsoResponse{ diff --git a/internal/master/grpc_service_test.go b/internal/master/grpc_service_test.go deleted file mode 100644 index 7736b53f75..0000000000 --- a/internal/master/grpc_service_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package master - -import ( - "context" - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" - "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" - "go.etcd.io/etcd/clientv3" - "google.golang.org/grpc" - "testing" -) - -func TestMaster_CreateCollection(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - etcdCli, err := clientv3.New(clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}}) - assert.Nil(t, err) - _, err = etcdCli.Delete(ctx, "/test/root", clientv3.WithPrefix()) - assert.Nil(t, err) - - svr, err := CreateServer(ctx, "/test/root/kv", "/test/root/meta", "/test/root/meta/tso", []string{"127.0.0.1:2379"}) - assert.Nil(t, err) - err = svr.Run(10001) - assert.Nil(t, err) - - conn, err := grpc.DialContext(ctx, "127.0.0.1:10001", grpc.WithInsecure(), grpc.WithBlock()) - assert.Nil(t, err) - defer conn.Close() - - cli := masterpb.NewMasterClient(conn) - sch := schemapb.CollectionSchema{ - Name: "col1", - Description: "test 
collection", - AutoId: false, - Fields: []*schemapb.FieldSchema{ - { - Name: "col1_f1", - Description: "test collection filed 1", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "col1_f1_tk1", - Value: "col1_f1_tv1", - }, - { - Key: "col1_f1_tk2", - Value: "col1_f1_tv2", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "col1_f1_ik1", - Value: "col1_f1_iv1", - }, - { - Key: "col1_f1_ik2", - Value: "col1_f1_iv2", - }, - }, - }, - { - Name: "col1_f2", - Description: "test collection filed 2", - DataType: schemapb.DataType_VECTOR_BINARY, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "col1_f2_tk1", - Value: "col1_f2_tv1", - }, - { - Key: "col1_f2_tk2", - Value: "col1_f2_tv2", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "col1_f2_ik1", - Value: "col1_f2_iv1", - }, - { - Key: "col1_f2_ik2", - Value: "col1_f2_iv2", - }, - }, - }, - }, - } - schema_bytes, err := proto.Marshal(&sch) - assert.Nil(t, err) - - req := internalpb.CreateCollectionRequest{ - MsgType: internalpb.MsgType_kCreateCollection, - ReqId: 1, - Timestamp: 11, - ProxyId: 1, - Schema: &commonpb.Blob{Value: schema_bytes}, - } - st, err := cli.CreateCollection(ctx, &req) - assert.Nil(t, err) - assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) - - coll_meta, err := svr.mt.GetCollectionByName(sch.Name) - assert.Nil(t, err) - t.Logf("collection id = %d", coll_meta.Id) - assert.Equal(t, coll_meta.CreateTime, uint64(11)) - assert.Equal(t, coll_meta.Schema.Name, "col1") - assert.Equal(t, coll_meta.Schema.AutoId, false) - assert.Equal(t, len(coll_meta.Schema.Fields), 2) - assert.Equal(t, coll_meta.Schema.Fields[0].Name, "col1_f1") - assert.Equal(t, coll_meta.Schema.Fields[1].Name, "col1_f2") - assert.Equal(t, coll_meta.Schema.Fields[0].DataType, schemapb.DataType_VECTOR_FLOAT) - assert.Equal(t, coll_meta.Schema.Fields[1].DataType, schemapb.DataType_VECTOR_BINARY) - assert.Equal(t, len(coll_meta.Schema.Fields[0].TypeParams), 2) - assert.Equal(t, len(coll_meta.Schema.Fields[0].IndexParams), 2) - assert.Equal(t, len(coll_meta.Schema.Fields[1].TypeParams), 2) - assert.Equal(t, len(coll_meta.Schema.Fields[1].IndexParams), 2) - assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[0].Key, "col1_f1_tk1") - assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[1].Key, "col1_f1_tk2") - assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[0].Value, "col1_f1_tv1") - assert.Equal(t, coll_meta.Schema.Fields[0].TypeParams[1].Value, "col1_f1_tv2") - assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[0].Key, "col1_f1_ik1") - assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[1].Key, "col1_f1_ik2") - assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[0].Value, "col1_f1_iv1") - assert.Equal(t, coll_meta.Schema.Fields[0].IndexParams[1].Value, "col1_f1_iv2") - - assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[0].Key, "col1_f2_tk1") - assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[1].Key, "col1_f2_tk2") - assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[0].Value, "col1_f2_tv1") - assert.Equal(t, coll_meta.Schema.Fields[1].TypeParams[1].Value, "col1_f2_tv2") - assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[0].Key, "col1_f2_ik1") - assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[1].Key, "col1_f2_ik2") - assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[0].Value, "col1_f2_iv1") - assert.Equal(t, coll_meta.Schema.Fields[1].IndexParams[1].Value, "col1_f2_iv2") - - svr.Close() -} diff --git a/internal/master/id/id.go 
b/internal/master/id/id.go index 8f584f3f20..4c4466e0a5 100644 --- a/internal/master/id/id.go +++ b/internal/master/id/id.go @@ -1,12 +1,13 @@ + package id import ( "github.com/zilliztech/milvus-distributed/internal/kv" "github.com/zilliztech/milvus-distributed/internal/master/tso" - "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" ) + type UniqueID = typeutil.UniqueID // GlobalTSOAllocator is the global single point TSO allocator. @@ -16,18 +17,13 @@ type GlobalIdAllocator struct { var allocator *GlobalIdAllocator -func Init() { - InitGlobalIdAllocator("idTimestamp", tsoutil.NewTSOKVBase("gid")) -} - -func InitGlobalIdAllocator(key string, base kv.KVBase) { +func InitGlobalIdAllocator(key string, base kv.KVBase){ allocator = NewGlobalIdAllocator(key, base) - allocator.Initialize() } -func NewGlobalIdAllocator(key string, base kv.KVBase) *GlobalIdAllocator { +func NewGlobalIdAllocator(key string, base kv.KVBase) * GlobalIdAllocator{ return &GlobalIdAllocator{ - allocator: tso.NewGlobalTSOAllocator(key, base), + allocator: tso.NewGlobalTSOAllocator( key, base), } } diff --git a/internal/master/id/id_test.go b/internal/master/id/id_test.go index 150af1d5c8..c80c75bc5f 100644 --- a/internal/master/id/id_test.go +++ b/internal/master/id/id_test.go @@ -1,19 +1,17 @@ package id import ( - "os" - "testing" - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/conf" - "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" + "github.com/zilliztech/milvus-distributed/internal/kv/mockkv" + "os" + + "testing" ) var GIdAllocator *GlobalIdAllocator func TestMain(m *testing.M) { - conf.LoadConfig("config.yaml") - GIdAllocator = NewGlobalIdAllocator("idTimestamp", tsoutil.NewTSOKVBase("gid")) + GIdAllocator = NewGlobalIdAllocator("idTimestamp", mockkv.NewEtcdKV()) exitCode := m.Run() os.Exit(exitCode) } @@ -32,8 +30,8 @@ func TestGlobalIdAllocator_AllocOne(t *testing.T) { } func TestGlobalIdAllocator_Alloc(t *testing.T) { - count := uint32(2 << 10) + count := uint32(2<<10) idStart, idEnd, err := GIdAllocator.Alloc(count) assert.Nil(t, err) - assert.Equal(t, count, uint32(idEnd-idStart)) -} + assert.Equal(t, count, uint32(idEnd - idStart)) +} \ No newline at end of file diff --git a/internal/master/master.go b/internal/master/master.go index f21e5ea471..506dd1b8f7 100644 --- a/internal/master/master.go +++ b/internal/master/master.go @@ -3,20 +3,18 @@ package master import ( "context" "fmt" - "github.com/zilliztech/milvus-distributed/internal/errors" "log" "math/rand" "net" + "path" "strconv" "sync" "sync/atomic" "time" - "github.com/zilliztech/milvus-distributed/internal/master/id" - "github.com/zilliztech/milvus-distributed/internal/master/tso" - "github.com/apache/pulsar-client-go/pulsar" "github.com/golang/protobuf/proto" + "github.com/zilliztech/milvus-distributed/internal/master/id" "github.com/zilliztech/milvus-distributed/internal/conf" "github.com/zilliztech/milvus-distributed/internal/kv" "github.com/zilliztech/milvus-distributed/internal/master/controller" @@ -25,6 +23,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" "google.golang.org/grpc" + "github.com/zilliztech/milvus-distributed/internal/master/tso" "go.etcd.io/etcd/clientv3" ) @@ -44,6 +43,9 @@ type Master struct { //grpc server grpcServer *grpc.Server + // for tso. 
+ tsoAllocator tso.Allocator + // pulsar client pc *informer.PulsarClient @@ -52,50 +54,46 @@ type Master struct { kvBase *kv.EtcdKV scheduler *ddRequestScheduler - mt *metaTable + mt metaTable // Add callback functions at different stages startCallbacks []func() closeCallbacks []func() } -func newKVBase(kv_root string, etcdAddr []string) *kv.EtcdKV { - cli, _ := clientv3.New(clientv3.Config{ - Endpoints: etcdAddr, +func newTSOKVBase(subPath string) * kv.EtcdKV{ + etcdAddr := conf.Config.Etcd.Address + etcdAddr += ":" + etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10) + client, _ := clientv3.New(clientv3.Config{ + Endpoints: []string{etcdAddr}, DialTimeout: 5 * time.Second, }) - kvBase := kv.NewEtcdKV(cli, kv_root) + return kv.NewEtcdKV(client, path.Join(conf.Config.Etcd.Rootpath, subPath)) +} + +func newKVBase() *kv.EtcdKV { + etcdAddr := conf.Config.Etcd.Address + etcdAddr += ":" + etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10) + cli, _ := clientv3.New(clientv3.Config{ + Endpoints: []string{etcdAddr}, + DialTimeout: 5 * time.Second, + }) + kvBase := kv.NewEtcdKV(cli, conf.Config.Etcd.Rootpath) return kvBase } -func Init() { - rand.Seed(time.Now().UnixNano()) - id.Init() - tso.Init() -} - // CreateServer creates the UNINITIALIZED pd server with given configuration. -func CreateServer(ctx context.Context, kv_root_path string, meta_root_path, tso_root_path string, etcdAddr []string) (*Master, error) { +func CreateServer(ctx context.Context) (*Master, error) { rand.Seed(time.Now().UnixNano()) - Init() - - etcdClient, err := clientv3.New(clientv3.Config{Endpoints: etcdAddr}) - if err != nil { - return nil, err - } - etcdkv := kv.NewEtcdKV(etcdClient, meta_root_path) - metakv, err := NewMetaTable(etcdkv) - if err != nil { - return nil, err - } - + id.InitGlobalIdAllocator("idTimestamp", newTSOKVBase("gid")) m := &Master{ ctx: ctx, startTimestamp: time.Now().Unix(), - kvBase: newKVBase(kv_root_path, etcdAddr), - scheduler: NewDDRequestScheduler(), - mt: metakv, + kvBase: newKVBase(), ssChan: make(chan internalpb.SegmentStatistics, 10), pc: informer.NewPulsarClient(), + tsoAllocator: tso.NewGlobalTSOAllocator("timestamp", newTSOKVBase("tso")), } m.grpcServer = grpc.NewServer() masterpb.RegisterMasterServer(m.grpcServer, m) @@ -153,13 +151,13 @@ func (s *Master) IsClosed() bool { } // Run runs the pd server. 
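
For context on the `CreateServer`/`Run` signature change in this file: the master no longer receives etcd endpoints, root paths, or the gRPC port as arguments; it derives them from `conf.Config` via `newKVBase`/`newTSOKVBase` and `conf.Config.Master.Port`. The snippet below is a condensed sketch of the resulting caller, essentially what `cmd/master/main.go` does after this patch, with config loading, zap logging, and signal handling elided.

```go
package main

import (
	"context"
	"log"

	"github.com/zilliztech/milvus-distributed/internal/master"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// CreateServer now reads the etcd address, port, and root path from
	// conf.Config, so no endpoints or root paths are passed in here.
	svr, err := master.CreateServer(ctx)
	if err != nil {
		log.Fatalf("create server failed: %v", err)
	}

	// Run likewise picks the gRPC port up from conf.Config.Master.Port.
	if err := svr.Run(); err != nil {
		log.Fatalf("run server failed: %v", err)
	}
}
```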
-func (s *Master) Run(grpcPort int64) error { +func (s *Master) Run() error { if err := s.startServer(s.ctx); err != nil { return err } - s.startServerLoop(s.ctx, grpcPort) + s.startServerLoop(s.ctx) return nil } @@ -174,28 +172,18 @@ func (s *Master) LoopContext() context.Context { return s.serverLoopCtx } -func (s *Master) startServerLoop(ctx context.Context, grpcPort int64) { +func (s *Master) startServerLoop(ctx context.Context) { s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(ctx) + s.serverLoopWg.Add(3) //go s.Se - - s.serverLoopWg.Add(1) - go s.grpcLoop(grpcPort) - - //s.serverLoopWg.Add(1) - //go s.pulsarLoop() - - s.serverLoopWg.Add(1) - go s.tasksExecutionLoop() - - s.serverLoopWg.Add(1) + go s.grpcLoop() + go s.pulsarLoop() go s.segmentStatisticsLoop() - } func (s *Master) stopServerLoop() { - if s.grpcServer != nil { + if s.grpcServer != nil{ s.grpcServer.GracefulStop() - log.Printf("server is cloded, exit grpc server") } s.serverLoopCancel() s.serverLoopWg.Wait() @@ -206,11 +194,11 @@ func (s *Master) StartTimestamp() int64 { return s.startTimestamp } -func (s *Master) grpcLoop(grpcPort int64) { +func (s *Master) grpcLoop() { defer s.serverLoopWg.Done() defaultGRPCPort := ":" - defaultGRPCPort += strconv.FormatInt(grpcPort, 10) + defaultGRPCPort += strconv.FormatInt(int64(conf.Config.Master.Port), 10) lis, err := net.Listen("tcp", defaultGRPCPort) if err != nil { log.Printf("failed to listen: %v", err) @@ -257,7 +245,7 @@ func (s *Master) pulsarLoop() { s.ssChan <- m consumer.Ack(msg) case <-ctx.Done(): - log.Print("server is closed, exit pulsar loop") + log.Print("server is closed, exit etcd leader loop") return } } @@ -270,16 +258,18 @@ func (s *Master) tasksExecutionLoop() { for { select { case task := <-s.scheduler.reqQueue: - timeStamp, err := (task).Ts() + timeStamp, err := (*task).Ts() if err != nil { log.Println(err) } else { if timeStamp < s.scheduler.scheduleTimeStamp { - task.Notify(errors.Errorf("input timestamp = %d, schduler timestamp = %d", timeStamp, s.scheduler.scheduleTimeStamp)) + _ = (*task).NotifyTimeout() } else { s.scheduler.scheduleTimeStamp = timeStamp - err = task.Execute() - task.Notify(err) + err := (*task).Execute() + if err != nil { + log.Println("request execution failed caused by error:", err) + } } } case <-ctx.Done(): @@ -300,7 +290,7 @@ func (s *Master) segmentStatisticsLoop() { case ss := <-s.ssChan: controller.ComputeCloseTime(ss, s.kvBase) case <-ctx.Done(): - log.Print("server is closed, exit segment statistics loop") + log.Print("server is closed, exit etcd leader loop") return } } diff --git a/internal/master/partition_task.go b/internal/master/partition_task.go index 733efb9dcb..f53f4e8303 100644 --- a/internal/master/partition_task.go +++ b/internal/master/partition_task.go @@ -59,6 +59,7 @@ func (t *createPartitionTask) Ts() (Timestamp, error) { func (t *createPartitionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -66,6 +67,7 @@ func (t *createPartitionTask) Execute() error { collectionName := partitionName.CollectionName collectionMeta, err := t.mt.GetCollectionByName(collectionName) if err != nil { + _ = t.Notify() return err } @@ -73,15 +75,18 @@ func (t *createPartitionTask) Execute() error { collectionJson, err := json.Marshal(&collectionMeta) if err != nil { - return err + _ = t.Notify() + return errors.New("marshal collection failed") } collectionId := collectionMeta.Id err = (*t.kvBase).Save(partitionMetaPrefix+strconv.FormatInt(collectionId, 10), 
string(collectionJson)) if err != nil { - return err + _ = t.Notify() + return errors.New("save collection failed") } + _ = t.Notify() return nil } @@ -103,6 +108,7 @@ func (t *dropPartitionTask) Ts() (Timestamp, error) { func (t *dropPartitionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -110,6 +116,7 @@ func (t *dropPartitionTask) Execute() error { collectionName := partitionName.CollectionName collectionMeta, err := t.mt.GetCollectionByName(collectionName) if err != nil { + _ = t.Notify() return err } @@ -120,15 +127,18 @@ func (t *dropPartitionTask) Execute() error { collectionJson, err := json.Marshal(&collectionMeta) if err != nil { - return err + _ = t.Notify() + return errors.New("marshal collection failed") } collectionId := collectionMeta.Id err = (*t.kvBase).Save(partitionMetaPrefix+strconv.FormatInt(collectionId, 10), string(collectionJson)) if err != nil { - return err + _ = t.Notify() + return errors.New("save collection failed") } + _ = t.Notify() return nil } @@ -150,6 +160,7 @@ func (t *hasPartitionTask) Ts() (Timestamp, error) { func (t *hasPartitionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -162,6 +173,7 @@ func (t *hasPartitionTask) Execute() error { t.hasPartition = t.mt.HasPartition(collectionMeta.Id, partitionName.Tag) + _ = t.Notify() return nil } @@ -183,6 +195,7 @@ func (t *describePartitionTask) Ts() (Timestamp, error) { func (t *describePartitionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -197,6 +210,7 @@ func (t *describePartitionTask) Execute() error { t.description = &description + _ = t.Notify() return nil } @@ -218,6 +232,7 @@ func (t *showPartitionTask) Ts() (Timestamp, error) { func (t *showPartitionTask) Execute() error { if t.req == nil { + _ = t.Notify() return errors.New("null request") } @@ -237,5 +252,6 @@ func (t *showPartitionTask) Execute() error { t.stringListResponse = &stringListResponse + _ = t.Notify() return nil } diff --git a/internal/master/scheduler.go b/internal/master/scheduler.go index 6dcc3bc1df..2c13fa3564 100644 --- a/internal/master/scheduler.go +++ b/internal/master/scheduler.go @@ -1,9 +1,7 @@ package master -import "math/rand" - type ddRequestScheduler struct { - reqQueue chan task + reqQueue chan *task scheduleTimeStamp Timestamp } @@ -11,17 +9,12 @@ func NewDDRequestScheduler() *ddRequestScheduler { const channelSize = 1024 rs := ddRequestScheduler{ - reqQueue: make(chan task, channelSize), + reqQueue: make(chan *task, channelSize), } return &rs } -func (rs *ddRequestScheduler) Enqueue(task task) error { +func (rs *ddRequestScheduler) Enqueue(task *task) error { rs.reqQueue <- task return nil } - -//TODO, allocGlobalId -func allocGlobalId() (UniqueID, error) { - return rand.Int63(), nil -} diff --git a/internal/master/task.go b/internal/master/task.go index 0c6bab5f4c..d29dbd8bb1 100644 --- a/internal/master/task.go +++ b/internal/master/task.go @@ -13,7 +13,7 @@ import ( type baseTask struct { kvBase *kv.EtcdKV mt *metaTable - cv chan error + cv chan int } type task interface { @@ -21,23 +21,27 @@ type task interface { Ts() (Timestamp, error) Execute() error WaitToFinish(ctx context.Context) error - Notify(err error) + Notify() error + NotifyTimeout() error } -func (bt *baseTask) Notify(err error) { - bt.cv <- err +func (bt *baseTask) Notify() error { + bt.cv <- 0 + return nil +} + +func (bt *baseTask) NotifyTimeout() error { + bt.cv <- 0 + return errors.New("request 
timeout") } func (bt *baseTask) WaitToFinish(ctx context.Context) error { for { select { case <-ctx.Done(): - return errors.Errorf("context done") - case err, ok := <-bt.cv: - if !ok { - return errors.Errorf("notify chan closed") - } - return err + return nil + case <-bt.cv: + return nil } } } diff --git a/internal/master/timesync/timesync.go b/internal/master/timesync/timesync.go index dbcabf5e5f..ce11cd1955 100644 --- a/internal/master/timesync/timesync.go +++ b/internal/master/timesync/timesync.go @@ -3,261 +3,213 @@ package timesync import ( "context" "log" - "math" + "sort" + "strconv" + "sync" + "time" - "github.com/zilliztech/milvus-distributed/internal/errors" - ms "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/conf" + + "github.com/apache/pulsar-client-go/pulsar" + "github.com/golang/protobuf/proto" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" ) -type ( - softTimeTickBarrier struct { - peer2LastTt map[UniqueID]Timestamp - minTtInterval Timestamp - lastTt Timestamp - outTt chan Timestamp - ttStream ms.MsgStream - ctx context.Context - closeCh chan struct{} // close goroutinue in Start() - closed bool +const stopReadFlagId int64 = -1 + +type TimeTickReader struct { + pulsarClient pulsar.Client + + timeTickConsumer pulsar.Consumer + readerProducer []pulsar.Producer + + interval int64 + proxyIdList []UniqueID + + timeTickPeerProxy map[UniqueID]Timestamp + ctx context.Context +} + +func (r *TimeTickReader) Start() { + go r.readTimeTick() + go r.timeSync() + +} + +func (r *TimeTickReader) Close() { + if r.timeTickConsumer != nil { + r.timeTickConsumer.Close() } - hardTimeTickBarrier struct { - peer2Tt map[UniqueID]Timestamp - outTt chan Timestamp - ttStream ms.MsgStream - ctx context.Context - closeCh chan struct{} // close goroutinue in Start() - closed bool + for i := 0; i < len(r.readerProducer); i++ { + if r.readerProducer[i] != nil { + r.readerProducer[i].Close() + } } -) + if r.pulsarClient != nil { + r.pulsarClient.Close() + } +} -func (ttBarrier *softTimeTickBarrier) GetTimeTick() (Timestamp, error) { - isEmpty := true +func (r *TimeTickReader) timeSync() { + ctx := r.ctx for { - - if ttBarrier.closed { - return 0, errors.Errorf("[GetTimeTick] closed.") - } - select { - case ts := <-ttBarrier.outTt: - isEmpty = false - ttBarrier.lastTt = ts - + case <-ctx.Done(): + return default: - if isEmpty { - continue + time.Sleep(time.Millisecond * time.Duration(r.interval)) + var minTimeStamp Timestamp + for _, minTimeStamp = range r.timeTickPeerProxy { + break } - return ttBarrier.lastTt, nil - } - } -} - -func (ttBarrier *softTimeTickBarrier) Start() error { - ttBarrier.closeCh = make(chan struct{}) - go func() { - for { - select { - - case <-ttBarrier.closeCh: - log.Printf("[TtBarrierStart] closed\n") - return - - case <-ttBarrier.ctx.Done(): - log.Printf("[TtBarrierStart] %s\n", ttBarrier.ctx.Err()) - ttBarrier.closed = true - return - - case ttmsgs := <-ttBarrier.ttStream.Chan(): - if len(ttmsgs.Msgs) > 0 { - for _, timetickmsg := range ttmsgs.Msgs { - ttmsg := (*timetickmsg).(*ms.TimeTickMsg) - oldT, ok := ttBarrier.peer2LastTt[ttmsg.PeerId] - log.Printf("[softTimeTickBarrier] peer(%d)=%d\n", ttmsg.PeerId, ttmsg.Timestamp) - - if !ok { - log.Printf("[softTimeTickBarrier] Warning: peerId %d not exist\n", ttmsg.PeerId) - continue - } - - if ttmsg.Timestamp > oldT { - ttBarrier.peer2LastTt[ttmsg.PeerId] = ttmsg.Timestamp - - // get a legal Timestamp - ts := ttBarrier.minTimestamp() - - if 
ttBarrier.lastTt != 0 && ttBarrier.minTtInterval > ts-ttBarrier.lastTt { - continue - } - - ttBarrier.outTt <- ts - } - } + for _, ts := range r.timeTickPeerProxy { + if ts < minTimeStamp { + minTimeStamp = ts } - - default: + } + //send timestamp flag to reader channel + msg := internalpb.TimeTickMsg{ + Timestamp: minTimeStamp, + MsgType: internalpb.MsgType_kTimeTick, + } + payload, err := proto.Marshal(&msg) + if err != nil { + //TODO log error + log.Printf("Marshal InsertOrDeleteMsg flag error %v", err) + } else { + wg := sync.WaitGroup{} + wg.Add(len(r.readerProducer)) + for index := range r.readerProducer { + go r.sendEOFMsg(ctx, &pulsar.ProducerMessage{Payload: payload}, index, &wg) + } + wg.Wait() } } - }() - return nil + } } -func NewSoftTimeTickBarrier(ctx context.Context, - ttStream *ms.MsgStream, - peerIds []UniqueID, - minTtInterval Timestamp) *softTimeTickBarrier { - - if len(peerIds) <= 0 { - log.Printf("[NewSoftTimeTickBarrier] Error: peerIds is emtpy!\n") - return nil - } - - sttbarrier := softTimeTickBarrier{} - sttbarrier.minTtInterval = minTtInterval - sttbarrier.ttStream = *ttStream - sttbarrier.outTt = make(chan Timestamp, 1024) - sttbarrier.ctx = ctx - sttbarrier.closed = false - - sttbarrier.peer2LastTt = make(map[UniqueID]Timestamp) - for _, id := range peerIds { - sttbarrier.peer2LastTt[id] = Timestamp(0) - } - if len(peerIds) != len(sttbarrier.peer2LastTt) { - log.Printf("[NewSoftTimeTickBarrier] Warning: there are duplicate peerIds!\n") - } - - return &sttbarrier -} - -func (ttBarrier *softTimeTickBarrier) Close() { - - if ttBarrier.closeCh != nil { - ttBarrier.closeCh <- struct{}{} - } - - ttBarrier.closed = true -} - -func (ttBarrier *softTimeTickBarrier) minTimestamp() Timestamp { - tempMin := Timestamp(math.MaxUint64) - for _, tt := range ttBarrier.peer2LastTt { - if tt < tempMin { - tempMin = tt - } - } - return tempMin -} - -func (ttBarrier *hardTimeTickBarrier) GetTimeTick() (Timestamp, error) { +func (r *TimeTickReader) readTimeTick() { for { - - if ttBarrier.closed { - return 0, errors.Errorf("[GetTimeTick] closed.") - } - select { - case ts := <-ttBarrier.outTt: - return ts, nil - default: - } - } -} - -func (ttBarrier *hardTimeTickBarrier) Start() error { - ttBarrier.closeCh = make(chan struct{}) - - go func() { - // Last timestamp synchronized - state := Timestamp(0) - for { - select { - - case <-ttBarrier.closeCh: - log.Printf("[TtBarrierStart] closed\n") - return - - case <-ttBarrier.ctx.Done(): - log.Printf("[TtBarrierStart] %s\n", ttBarrier.ctx.Err()) - ttBarrier.closed = true - return - - case ttmsgs := <-ttBarrier.ttStream.Chan(): - if len(ttmsgs.Msgs) > 0 { - for _, timetickmsg := range ttmsgs.Msgs { - - // Suppose ttmsg.Timestamp from stream is always larger than the previous one, - // that `ttmsg.Timestamp > oldT` - ttmsg := (*timetickmsg).(*ms.TimeTickMsg) - log.Printf("[hardTimeTickBarrier] peer(%d)=%d\n", ttmsg.PeerId, ttmsg.Timestamp) - - oldT, ok := ttBarrier.peer2Tt[ttmsg.PeerId] - if !ok { - log.Printf("[hardTimeTickBarrier] Warning: peerId %d not exist\n", ttmsg.PeerId) - continue - } - - if oldT > state { - log.Printf("[hardTimeTickBarrier] Warning: peer(%d) timestamp(%d) ahead\n", - ttmsg.PeerId, ttmsg.Timestamp) - } - - ttBarrier.peer2Tt[ttmsg.PeerId] = ttmsg.Timestamp - - newState := ttBarrier.minTimestamp() - if newState > state { - ttBarrier.outTt <- newState - state = newState - } - } - } - default: + case <-r.ctx.Done(): + return + case cm, ok := <-r.timeTickConsumer.Chan(): + if ok == false { + log.Printf("timesync consumer 
closed") } - } - }() - return nil -} -func (ttBarrier *hardTimeTickBarrier) minTimestamp() Timestamp { - tempMin := Timestamp(math.MaxUint64) - for _, tt := range ttBarrier.peer2Tt { - if tt < tempMin { - tempMin = tt + msg := cm.Message + var tsm internalpb.TimeTickMsg + if err := proto.Unmarshal(msg.Payload(), &tsm); err != nil { + log.Printf("UnMarshal timetick flag error %v", err) + } + + r.timeTickPeerProxy[tsm.PeerId] = tsm.Timestamp + r.timeTickConsumer.AckID(msg.ID()) } } - return tempMin } -func NewHardTimeTickBarrier(ctx context.Context, - ttStream *ms.MsgStream, - peerIds []UniqueID) *hardTimeTickBarrier { - - if len(peerIds) <= 0 { - log.Printf("[NewSoftTimeTickBarrier] Error: peerIds is emtpy!") - return nil +func (r *TimeTickReader) sendEOFMsg(ctx context.Context, msg *pulsar.ProducerMessage, index int, wg *sync.WaitGroup) { + if _, err := r.readerProducer[index].Send(ctx, msg); err != nil { + log.Printf("Send timesync flag error %v", err) } - - sttbarrier := hardTimeTickBarrier{} - sttbarrier.ttStream = *ttStream - sttbarrier.outTt = make(chan Timestamp, 1024) - sttbarrier.ctx = ctx - sttbarrier.closed = false - - sttbarrier.peer2Tt = make(map[UniqueID]Timestamp) - for _, id := range peerIds { - sttbarrier.peer2Tt[id] = Timestamp(0) - } - if len(peerIds) != len(sttbarrier.peer2Tt) { - log.Printf("[NewSoftTimeTickBarrier] Warning: there are duplicate peerIds!") - } - - return &sttbarrier + wg.Done() } -func (ttBarrier *hardTimeTickBarrier) Close() { - if ttBarrier.closeCh != nil { - ttBarrier.closeCh <- struct{}{} +func TimeTickService() { + timeTickTopic := "timeTick" + timeTickSubName := "master" + readTopics := make([]string, 0) + for i := conf.Config.Reader.TopicStart; i < conf.Config.Reader.TopicEnd; i++ { + str := "InsertOrDelete-" + str = str + strconv.Itoa(i) + readTopics = append(readTopics, str) } - ttBarrier.closed = true - return + + proxyIdList := conf.Config.Master.ProxyIdList + timeTickReader := newTimeTickReader(context.Background(), timeTickTopic, timeTickSubName, readTopics, proxyIdList) + timeTickReader.Start() +} + +func newTimeTickReader( + ctx context.Context, + timeTickTopic string, + timeTickSubName string, + readTopics []string, + proxyIdList []UniqueID, +) *TimeTickReader { + pulsarAddr := "pulsar://" + pulsarAddr += conf.Config.Pulsar.Address + pulsarAddr += ":" + pulsarAddr += strconv.FormatInt(int64(conf.Config.Pulsar.Port), 10) + interval := int64(conf.Config.Timesync.Interval) + + //check if proxyId has duplication + if len(proxyIdList) == 0 { + log.Printf("proxy id list is empty") + } + if len(proxyIdList) > 1 { + sort.Slice(proxyIdList, func(i int, j int) bool { return proxyIdList[i] < proxyIdList[j] }) + } + for i := 1; i < len(proxyIdList); i++ { + if proxyIdList[i] == proxyIdList[i-1] { + log.Printf("there are two proxies have the same id = %d", proxyIdList[i]) + } + } + r := TimeTickReader{} + r.interval = interval + r.proxyIdList = proxyIdList + readerQueueSize := conf.Config.Reader.ReaderQueueSize + + //check if read topic is empty + if len(readTopics) == 0 { + log.Printf("read topic is empyt") + } + //set default value + if readerQueueSize == 0 { + readerQueueSize = 1024 + } + + r.timeTickPeerProxy = make(map[UniqueID]Timestamp) + r.ctx = ctx + + var client pulsar.Client + var err error + if conf.Config.Pulsar.Authentication { + client, err = pulsar.NewClient(pulsar.ClientOptions{ + URL: pulsarAddr, + Authentication: pulsar.NewAuthenticationToken(conf.Config.Pulsar.Token), + }) + } else { + client, err = 
pulsar.NewClient(pulsar.ClientOptions{URL: pulsarAddr}) + } + + if err != nil { + log.Printf("connect pulsar failed, %v", err) + } + r.pulsarClient = client + + timeSyncChan := make(chan pulsar.ConsumerMessage, len(r.proxyIdList)) + if r.timeTickConsumer, err = r.pulsarClient.Subscribe(pulsar.ConsumerOptions{ + Topic: timeTickTopic, + SubscriptionName: timeTickSubName, + Type: pulsar.KeyShared, + SubscriptionInitialPosition: pulsar.SubscriptionPositionEarliest, + MessageChannel: timeSyncChan, + }); err != nil { + log.Printf("failed to subscribe topic %s, error = %v", timeTickTopic, err) + } + + r.readerProducer = make([]pulsar.Producer, 0, len(readTopics)) + for i := 0; i < len(readTopics); i++ { + rp, err := r.pulsarClient.CreateProducer(pulsar.ProducerOptions{Topic: readTopics[i]}) + if err != nil { + log.Printf("failed to create reader producer %s, error = %v", readTopics[i], err) + } + r.readerProducer = append(r.readerProducer, rp) + } + + return &r } diff --git a/internal/master/timesync/timesync_test.go b/internal/master/timesync/timesync_test.go deleted file mode 100644 index 06ad10162a..0000000000 --- a/internal/master/timesync/timesync_test.go +++ /dev/null @@ -1,426 +0,0 @@ -package timesync - -import ( - "context" - "log" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - ms "github.com/zilliztech/milvus-distributed/internal/msgstream" - internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" -) - -func getTtMsg(msgType internalPb.MsgType, peerId UniqueID, timeStamp uint64) *ms.TsMsg { - var tsMsg ms.TsMsg - baseMsg := ms.BaseMsg{ - HashValues: []int32{int32(peerId)}, - } - timeTickResult := internalPb.TimeTickMsg{ - MsgType: internalPb.MsgType_kTimeTick, - PeerId: peerId, - Timestamp: timeStamp, - } - timeTickMsg := &ms.TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, - } - tsMsg = timeTickMsg - return &tsMsg -} - -func initPulsarStream(pulsarAddress string, - producerChannels []string, - consumerChannels []string, - consumerSubName string) (*ms.MsgStream, *ms.MsgStream) { - - // set input stream - inputStream := ms.NewPulsarMsgStream(context.Background(), 100) - inputStream.SetPulsarCient(pulsarAddress) - inputStream.CreatePulsarProducers(producerChannels) - var input ms.MsgStream = inputStream - - // set output stream - outputStream := ms.NewPulsarMsgStream(context.Background(), 100) - outputStream.SetPulsarCient(pulsarAddress) - unmarshalDispatcher := ms.NewUnmarshalDispatcher() - outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) - outputStream.Start() - var output ms.MsgStream = outputStream - - return &input, &output -} - -func getMsgPack(ttmsgs [][2]int) *ms.MsgPack { - msgPack := ms.MsgPack{} - for _, vi := range ttmsgs { - msgPack.Msgs = append(msgPack.Msgs, getTtMsg(internalPb.MsgType_kTimeTick, UniqueID(vi[0]), Timestamp(vi[1]))) - } - return &msgPack -} - -func getEmptyMsgPack() *ms.MsgPack { - msgPack := ms.MsgPack{} - return &msgPack -} - -func producer(channels []string, ttmsgs [][2]int) (*ms.MsgStream, *ms.MsgStream) { - pulsarAddress := "pulsar://localhost:6650" - consumerSubName := "subTimetick" - producerChannels := channels - consumerChannels := channels - - inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - - msgPackAddr := getMsgPack(ttmsgs) - (*inputStream).Produce(msgPackAddr) - return inputStream, outputStream -} - -func 
TestTt_NewSoftTtBarrier(t *testing.T) { - channels := []string{"NewSoftTtBarrier"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - minTtInterval := Timestamp(10) - - validPeerIds := []UniqueID{1, 2, 3} - - sttbarrier := NewSoftTimeTickBarrier(context.TODO(), ttStream, validPeerIds, minTtInterval) - assert.NotNil(t, sttbarrier) - sttbarrier.Close() - - validPeerIds2 := []UniqueID{1, 1, 1} - sttbarrier = NewSoftTimeTickBarrier(context.TODO(), ttStream, validPeerIds2, minTtInterval) - assert.NotNil(t, sttbarrier) - sttbarrier.Close() - - // invalid peerIds - invalidPeerIds1 := make([]UniqueID, 0, 3) - sttbarrier = NewSoftTimeTickBarrier(context.TODO(), ttStream, invalidPeerIds1, minTtInterval) - assert.Nil(t, sttbarrier) - - invalidPeerIds2 := []UniqueID{} - sttbarrier = NewSoftTimeTickBarrier(context.TODO(), ttStream, invalidPeerIds2, minTtInterval) - assert.Nil(t, sttbarrier) -} - -func TestTt_NewHardTtBarrier(t *testing.T) { - channels := []string{"NewHardTtBarrier"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - validPeerIds := []UniqueID{1, 2, 3} - - sttbarrier := NewHardTimeTickBarrier(context.TODO(), ttStream, validPeerIds) - assert.NotNil(t, sttbarrier) - sttbarrier.Close() - - validPeerIds2 := []UniqueID{1, 1, 1} - sttbarrier = NewHardTimeTickBarrier(context.TODO(), ttStream, validPeerIds2) - assert.NotNil(t, sttbarrier) - sttbarrier.Close() - - // invalid peerIds - invalidPeerIds1 := make([]UniqueID, 0, 3) - sttbarrier = NewHardTimeTickBarrier(context.TODO(), ttStream, invalidPeerIds1) - assert.Nil(t, sttbarrier) - - invalidPeerIds2 := []UniqueID{} - sttbarrier = NewHardTimeTickBarrier(context.TODO(), ttStream, invalidPeerIds2) - assert.Nil(t, sttbarrier) -} - -func TestTt_SoftTtBarrierStart(t *testing.T) { - channels := []string{"SoftTtBarrierStart"} - - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - minTtInterval := Timestamp(10) - peerIds := []UniqueID{1, 2, 3} - sttbarrier := NewSoftTimeTickBarrier(context.TODO(), ttStream, peerIds, minTtInterval) - require.NotNil(t, sttbarrier) - - sttbarrier.Start() - defer sttbarrier.Close() - - // Make sure all msgs in outputStream is consumed - time.Sleep(100 * time.Millisecond) - - ts, err := sttbarrier.GetTimeTick() - assert.Nil(t, err) - assert.Equal(t, Timestamp(30), ts) -} - -func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) { - channels := []string{"SoftTtBarrierGetTimeTickClose"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - minTtInterval := Timestamp(10) - validPeerIds := []UniqueID{1, 2, 3} - - sttbarrier := NewSoftTimeTickBarrier(context.TODO(), ttStream, validPeerIds, minTtInterval) - require.NotNil(t, sttbarrier) - - sttbarrier.Start() - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - sttbarrier.Close() - }() - wg.Wait() - - ts, err := sttbarrier.GetTimeTick() - assert.NotNil(t, err) - assert.Equal(t, Timestamp(0), ts) - - // Receive empty 
msgPacks - channels01 := []string{"GetTimeTick01"} - ttmsgs01 := [][2]int{} - inStream01, ttStream01 := producer(channels01, ttmsgs01) - defer func() { - (*inStream01).Close() - (*ttStream01).Close() - }() - - minTtInterval = Timestamp(10) - validPeerIds = []UniqueID{1, 2, 3} - - sttbarrier01 := NewSoftTimeTickBarrier(context.TODO(), ttStream01, validPeerIds, minTtInterval) - require.NotNil(t, sttbarrier01) - sttbarrier01.Start() - - var wg1 sync.WaitGroup - wg1.Add(1) - - go func() { - defer wg1.Done() - sttbarrier01.Close() - }() - - ts, err = sttbarrier01.GetTimeTick() - assert.NotNil(t, err) - assert.Equal(t, Timestamp(0), ts) -} - -func TestTt_SoftTtBarrierGetTimeTickCancel(t *testing.T) { - channels := []string{"SoftTtBarrierGetTimeTickCancel"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - minTtInterval := Timestamp(10) - validPeerIds := []UniqueID{1, 2, 3} - - ctx, cancel := context.WithCancel(context.Background()) - sttbarrier := NewSoftTimeTickBarrier(ctx, ttStream, validPeerIds, minTtInterval) - require.NotNil(t, sttbarrier) - - sttbarrier.Start() - - go func() { - time.Sleep(10 * time.Millisecond) - cancel() - time.Sleep(10 * time.Millisecond) - sttbarrier.Close() - }() - - time.Sleep(20 * time.Millisecond) - - ts, err := sttbarrier.GetTimeTick() - assert.NotNil(t, err) - assert.Equal(t, Timestamp(0), ts) - log.Println(err) -} - -func TestTt_HardTtBarrierStart(t *testing.T) { - channels := []string{"HardTtBarrierStart"} - - ttmsgs := [][2]int{ - {1, 10}, - {2, 10}, - {3, 10}, - } - - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - peerIds := []UniqueID{1, 2, 3} - sttbarrier := NewHardTimeTickBarrier(context.TODO(), ttStream, peerIds) - require.NotNil(t, sttbarrier) - - sttbarrier.Start() - defer sttbarrier.Close() - - // Make sure all msgs in outputStream is consumed - time.Sleep(100 * time.Millisecond) - - ts, err := sttbarrier.GetTimeTick() - assert.Nil(t, err) - assert.Equal(t, Timestamp(10), ts) -} - -func TestTt_HardTtBarrierGetTimeTick(t *testing.T) { - - channels := []string{"HardTtBarrierGetTimeTick"} - - ttmsgs := [][2]int{ - {1, 10}, - {1, 20}, - {1, 30}, - {2, 10}, - {2, 20}, - {3, 10}, - {3, 20}, - } - - inStream, ttStream := producer(channels, ttmsgs) - defer func() { - (*inStream).Close() - (*ttStream).Close() - }() - - peerIds := []UniqueID{1, 2, 3} - sttbarrier := NewHardTimeTickBarrier(context.TODO(), ttStream, peerIds) - require.NotNil(t, sttbarrier) - - sttbarrier.Start() - defer sttbarrier.Close() - - // Make sure all msgs in outputStream is consumed - time.Sleep(100 * time.Millisecond) - - ts, err := sttbarrier.GetTimeTick() - assert.Nil(t, err) - assert.Equal(t, Timestamp(10), ts) - - ts, err = sttbarrier.GetTimeTick() - assert.Nil(t, err) - assert.Equal(t, Timestamp(20), ts) - - // ---------------------stuck-------------------------- - channelsStuck := []string{"HardTtBarrierGetTimeTickStuck"} - - ttmsgsStuck := [][2]int{ - {1, 10}, - {2, 10}, - } - - inStreamStuck, ttStreamStuck := producer(channelsStuck, ttmsgsStuck) - defer func() { - (*inStreamStuck).Close() - (*ttStreamStuck).Close() - }() - - peerIdsStuck := []UniqueID{1, 2, 3} - sttbarrierStuck := NewHardTimeTickBarrier(context.TODO(), ttStreamStuck, peerIdsStuck) - require.NotNil(t, sttbarrierStuck) - - sttbarrierStuck.Start() - go func() { - time.Sleep(1 * 
time.Second) - sttbarrierStuck.Close() - }() - - time.Sleep(100 * time.Millisecond) - - // This will stuck - ts, err = sttbarrierStuck.GetTimeTick() - - // ---------------------context cancel------------------------ - channelsCancel := []string{"HardTtBarrierGetTimeTickCancel"} - - ttmsgsCancel := [][2]int{ - {1, 10}, - {2, 10}, - } - - inStreamCancel, ttStreamCancel := producer(channelsCancel, ttmsgsCancel) - defer func() { - (*inStreamCancel).Close() - (*ttStreamCancel).Close() - }() - - peerIdsCancel := []UniqueID{1, 2, 3} - - ctx, cancel := context.WithCancel(context.Background()) - sttbarrierCancel := NewHardTimeTickBarrier(ctx, ttStreamCancel, peerIdsCancel) - require.NotNil(t, sttbarrierCancel) - - sttbarrierCancel.Start() - go func() { - time.Sleep(1 * time.Second) - cancel() - }() - - time.Sleep(100 * time.Millisecond) - - // This will stuck - ts, err = sttbarrierCancel.GetTimeTick() - -} diff --git a/internal/master/timesync/timetick.go b/internal/master/timesync/timetick.go index 715f78f4e9..68ec760ca5 100644 --- a/internal/master/timesync/timetick.go +++ b/internal/master/timesync/timetick.go @@ -2,13 +2,10 @@ package timesync import "github.com/zilliztech/milvus-distributed/internal/util/typeutil" -type ( - UniqueID = typeutil.UniqueID - Timestamp = typeutil.Timestamp -) +type UniqueID = typeutil.UniqueID +type Timestamp = typeutil.Timestamp type TimeTickBarrier interface { - GetTimeTick() (Timestamp, error) + GetTimeTick() (Timestamp,error) Start() error - Close() } diff --git a/internal/master/tso/global_allocator.go b/internal/master/tso/global_allocator.go index 921dd94ddd..d553911675 100644 --- a/internal/master/tso/global_allocator.go +++ b/internal/master/tso/global_allocator.go @@ -36,19 +36,8 @@ type GlobalTSOAllocator struct { tso *timestampOracle } -var allocator *GlobalTSOAllocator - -func Init() { - InitGlobalTsoAllocator("timestamp", tsoutil.NewTSOKVBase("tso")) -} - -func InitGlobalTsoAllocator(key string, base kv.KVBase) { - allocator = NewGlobalTSOAllocator(key, base) - allocator.Initialize() -} - // NewGlobalTSOAllocator creates a new global TSO allocator. -func NewGlobalTSOAllocator(key string, kvBase kv.KVBase) *GlobalTSOAllocator { +func NewGlobalTSOAllocator(key string, kvBase kv.KVBase) Allocator { var saveInterval time.Duration = 3 * time.Second return &GlobalTSOAllocator{ @@ -63,7 +52,7 @@ func NewGlobalTSOAllocator(key string, kvBase kv.KVBase) *GlobalTSOAllocator { // Initialize will initialize the created global TSO allocator. func (gta *GlobalTSOAllocator) Initialize() error { - return gta.tso.InitTimestamp() + return gta.tso.SyncTimestamp() } // UpdateTSO is used to update the TSO in memory and the time window in etcd. @@ -108,33 +97,7 @@ func (gta *GlobalTSOAllocator) GenerateTSO(count uint32) (uint64, error) { return 0, errors.New("can not get timestamp") } -func (gta *GlobalTSOAllocator) Alloc(count uint32) (typeutil.Timestamp, error) { - //return gta.tso.SyncTimestamp() - start, err := gta.GenerateTSO(count) - if err != nil { - return typeutil.ZeroTimestamp, err - } - //ret := make([]typeutil.Timestamp, count) - //for i:=uint32(0); i < count; i++{ - // ret[i] = start + uint64(i) - //} - return start, err -} - -func (gta *GlobalTSOAllocator) AllocOne() (typeutil.Timestamp, error) { - return gta.GenerateTSO(1) -} - // Reset is used to reset the TSO allocator. 
func (gta *GlobalTSOAllocator) Reset() { gta.tso.ResetTimestamp() } - -func AllocOne() (typeutil.Timestamp, error) { - return allocator.AllocOne() -} - -// Reset is used to reset the TSO allocator. -func Alloc(count uint32) (typeutil.Timestamp, error) { - return allocator.Alloc(count) -} diff --git a/internal/master/tso/global_allocator_test.go b/internal/master/tso/global_allocator_test.go index 3b4634ed1b..91d6459d7c 100644 --- a/internal/master/tso/global_allocator_test.go +++ b/internal/master/tso/global_allocator_test.go @@ -1,21 +1,18 @@ package tso import ( + "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/kv/mockkv" + "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" "os" "testing" "time" - - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/conf" - "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" ) var GTsoAllocator Allocator func TestMain(m *testing.M) { - conf.LoadConfig("config.yaml") - GTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase("tso")) - + GTsoAllocator = NewGlobalTSOAllocator("timestamp", mockkv.NewEtcdKV()) exitCode := m.Run() os.Exit(exitCode) } @@ -31,7 +28,7 @@ func TestGlobalTSOAllocator_GenerateTSO(t *testing.T) { startTs, err := GTsoAllocator.GenerateTSO(perCount) assert.Nil(t, err) lastPhysical, lastLogical := tsoutil.ParseTS(startTs) - for i := 0; i < count; i++ { + for i:=0;i < count; i++{ ts, _ := GTsoAllocator.GenerateTSO(perCount) physical, logical := tsoutil.ParseTS(ts) if lastPhysical == physical { @@ -44,7 +41,7 @@ func TestGlobalTSOAllocator_GenerateTSO(t *testing.T) { func TestGlobalTSOAllocator_SetTSO(t *testing.T) { curTime := time.Now() - nextTime := curTime.Add(2 * time.Second) + nextTime := curTime.Add(2 * time.Second ) physical := nextTime.UnixNano() / int64(time.Millisecond) logical := int64(0) err := GTsoAllocator.SetTSO(tsoutil.ComposeTS(physical, logical)) diff --git a/internal/master/tso/tso.go b/internal/master/tso/tso.go index f66709a02d..111e986983 100644 --- a/internal/master/tso/tso.go +++ b/internal/master/tso/tso.go @@ -46,8 +46,8 @@ type atomicObject struct { // timestampOracle is used to maintain the logic of tso. type timestampOracle struct { - key string - kvBase kv.KVBase + key string + kvBase kv.KVBase // TODO: remove saveInterval saveInterval time.Duration @@ -83,27 +83,28 @@ func (t *timestampOracle) saveTimestamp(ts time.Time) error { return nil } -func (t *timestampOracle) InitTimestamp() error { +// SyncTimestamp is used to synchronize the timestamp. +func (t *timestampOracle) SyncTimestamp() error { - //last, err := t.loadTimestamp() - //if err != nil { - // return err - //} + last, err := t.loadTimestamp() + if err != nil { + return err + } next := time.Now() // If the current system time minus the saved etcd timestamp is less than `updateTimestampGuard`, // the timestamp allocation will start from the saved etcd timestamp temporarily. 
- //if typeutil.SubTimeByWallClock(next, last) < updateTimestampGuard { - // next = last.Add(updateTimestampGuard) - //} + if typeutil.SubTimeByWallClock(next, last) < updateTimestampGuard { + next = last.Add(updateTimestampGuard) + } save := next.Add(t.saveInterval) - if err := t.saveTimestamp(save); err != nil { + if err = t.saveTimestamp(save); err != nil { return err } - //log.Print("sync and save timestamp", zap.Time("last", last), zap.Time("save", save), zap.Time("next", next)) + log.Print("sync and save timestamp", zap.Time("last", last), zap.Time("save", save), zap.Time("next", next)) current := &atomicObject{ physical: next, @@ -155,7 +156,7 @@ func (t *timestampOracle) UpdateTimestamp() error { now := time.Now() jetLag := typeutil.SubTimeByWallClock(now, prev.physical) - if jetLag > 3*UpdateTimestampStep { + if jetLag > 3 * UpdateTimestampStep { log.Print("clock offset", zap.Duration("jet-lag", jetLag), zap.Time("prev-physical", prev.physical), zap.Time("now", now)) } @@ -196,7 +197,7 @@ func (t *timestampOracle) UpdateTimestamp() error { // ResetTimestamp is used to reset the timestamp. func (t *timestampOracle) ResetTimestamp() { zero := &atomicObject{ - physical: time.Now(), + physical: typeutil.ZeroTime, } atomic.StorePointer(&t.TSO, unsafe.Pointer(zero)) } diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index cd5814f69a..c964a709b3 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -2,12 +2,10 @@ package msgstream import ( "context" - "github.com/zilliztech/milvus-distributed/internal/errors" "log" "sync" "github.com/golang/protobuf/proto" - commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "github.com/apache/pulsar-client-go/pulsar" @@ -24,7 +22,7 @@ type MsgPack struct { Msgs []*TsMsg } -type RepackFunc func(msgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) +type RepackFunc func(msgs []*TsMsg, hashKeys [][]int32) map[int32]*MsgPack type MsgStream interface { Start() @@ -44,18 +42,18 @@ type PulsarMsgStream struct { repackFunc RepackFunc unmarshal *UnmarshalDispatcher receiveBuf chan *MsgPack + receiveBufSize int64 wait *sync.WaitGroup streamCancel func() } func NewPulsarMsgStream(ctx context.Context, receiveBufSize int64) *PulsarMsgStream { streamCtx, streamCancel := context.WithCancel(ctx) - stream := &PulsarMsgStream{ + return &PulsarMsgStream{ ctx: streamCtx, streamCancel: streamCancel, + receiveBufSize: receiveBufSize, } - stream.receiveBuf = make(chan *MsgPack, receiveBufSize) - return stream } func (ms *PulsarMsgStream) SetPulsarCient(address string) { @@ -147,23 +145,22 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { } var result map[int32]*MsgPack - var err error if ms.repackFunc != nil { - result, err = ms.repackFunc(tsMsgs, reBucketValues) + result = ms.repackFunc(tsMsgs, reBucketValues) } else { - msgType := (*tsMsgs[0]).Type() - switch msgType { - case internalPb.MsgType_kInsert: - result, err = insertRepackFunc(tsMsgs, reBucketValues) - case internalPb.MsgType_kDelete: - result, err = deleteRepackFunc(tsMsgs, reBucketValues) - default: - result, err = defaultRepackFunc(tsMsgs, reBucketValues) + result = make(map[int32]*MsgPack) + for i, request := range tsMsgs { + keys := reBucketValues[i] + for _, channelID := range keys { + _, ok := result[channelID] + if !ok { + msgPack := MsgPack{} + result[channelID] = &msgPack + } + result[channelID].Msgs = 
append(result[channelID].Msgs, request) + } } } - if err != nil { - return err - } for k, v := range result { for i := 0; i < len(v.Msgs); i++ { mb, err := (*v.Msgs[i]).Marshal(v.Msgs[i]) @@ -218,6 +215,7 @@ func (ms *PulsarMsgStream) Consume() *MsgPack { func (ms *PulsarMsgStream) bufMsgPackToChannel() { defer ms.wait.Done() + ms.receiveBuf = make(chan *MsgPack, ms.receiveBufSize) for { select { case <-ms.ctx.Done(): @@ -273,8 +271,8 @@ func NewPulsarTtMsgStream(ctx context.Context, receiveBufSize int64) *PulsarTtMs pulsarMsgStream := PulsarMsgStream{ ctx: streamCtx, streamCancel: streamCancel, + receiveBufSize: receiveBufSize, } - pulsarMsgStream.receiveBuf = make(chan *MsgPack, receiveBufSize) return &PulsarTtMsgStream{ PulsarMsgStream: pulsarMsgStream, } @@ -290,6 +288,7 @@ func (ms *PulsarTtMsgStream) Start() { func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { defer ms.wait.Done() + ms.receiveBuf = make(chan *MsgPack, ms.receiveBufSize) ms.unsolvedBuf = make([]*TsMsg, 0) ms.inputBuf = make([]*TsMsg, 0) for { @@ -384,113 +383,3 @@ func checkTimeTickMsg(msg map[int]Timestamp) (Timestamp, bool) { } return 0, false } - -func insertRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - if (*request).Type() != internalPb.MsgType_kInsert { - return nil, errors.New(string("msg's must be Insert")) - } - insertRequest := (*request).(*InsertMsg) - keys := hashKeys[i] - - timestampLen := len(insertRequest.Timestamps) - rowIDLen := len(insertRequest.RowIds) - rowDataLen := len(insertRequest.RowData) - keysLen := len(keys) - - if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { - return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal")) - } - for index, key := range keys { - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - - sliceRequest := internalPb.InsertRequest{ - MsgType: internalPb.MsgType_kInsert, - ReqId: insertRequest.ReqId, - CollectionName: insertRequest.CollectionName, - PartitionTag: insertRequest.PartitionTag, - SegmentId: insertRequest.SegmentId, - ChannelId: insertRequest.ChannelId, - ProxyId: insertRequest.ProxyId, - Timestamps: []uint64{insertRequest.Timestamps[index]}, - RowIds: []int64{insertRequest.RowIds[index]}, - RowData: []*commonPb.Blob{insertRequest.RowData[index]}, - } - - var msg TsMsg = &InsertMsg{ - InsertRequest: sliceRequest, - } - - result[key].Msgs = append(result[key].Msgs, &msg) - } - } - return result, nil -} - -func deleteRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - if (*request).Type() != internalPb.MsgType_kDelete { - return nil, errors.New(string("msg's must be Delete")) - } - deleteRequest := (*request).(*DeleteMsg) - keys := hashKeys[i] - - timestampLen := len(deleteRequest.Timestamps) - primaryKeysLen := len(deleteRequest.PrimaryKeys) - keysLen := len(keys) - - if keysLen != timestampLen || keysLen != primaryKeysLen { - return nil, errors.New(string("the length of hashValue, timestamps, primaryKeys are not equal")) - } - - for index, key := range keys { - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - - sliceRequest := internalPb.DeleteRequest{ - MsgType: internalPb.MsgType_kDelete, - ReqId: deleteRequest.ReqId, - CollectionName: deleteRequest.CollectionName, - ChannelId: deleteRequest.ChannelId, - ProxyId: 
deleteRequest.ProxyId, - Timestamps: []uint64{deleteRequest.Timestamps[index]}, - PrimaryKeys: []int64{deleteRequest.PrimaryKeys[index]}, - } - - var msg TsMsg = &DeleteMsg{ - DeleteRequest: sliceRequest, - } - - result[key].Msgs = append(result[key].Msgs, &msg) - } - } - return result, nil -} - -func defaultRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - keys := hashKeys[i] - if len(keys) != 1 { - return nil, errors.New(string("len(msg.hashValue) must equal 1")) - } - key := keys[0] - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - - result[key].Msgs = append(result[key].Msgs, request) - } - return result, nil -} diff --git a/internal/msgstream/msgstream_test.go b/internal/msgstream/msgstream_test.go index c168961cc5..a0060583be 100644 --- a/internal/msgstream/msgstream_test.go +++ b/internal/msgstream/msgstream_test.go @@ -3,14 +3,13 @@ package msgstream import ( "context" "fmt" - "log" "testing" commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" ) -func repackFunc(msgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { +func repackFunc(msgs []*TsMsg, hashKeys [][]int32) map[int32]*MsgPack { result := make(map[int32]*MsgPack) for i, request := range msgs { keys := hashKeys[i] @@ -23,7 +22,7 @@ func repackFunc(msgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { result[channelID].Msgs = append(result[channelID].Msgs, request) } } - return result, nil + return result } func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) *TsMsg { @@ -44,8 +43,6 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) *TsMsg { ChannelId: 1, ProxyId: 1, Timestamps: []Timestamp{1}, - RowIds: []int64{1}, - RowData: []*commonPb.Blob{{}}, } insertMsg := &InsertMsg{ BaseMsg: baseMsg, @@ -212,11 +209,7 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kInsert, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } - + (*inputStream).Produce(&msgPack) receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -234,10 +227,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kDelete, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + (*inputStream).Produce(&msgPack) receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -254,10 +244,7 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kSearch, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + (*inputStream).Produce(&msgPack) receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -274,10 +261,7 @@ func TestStream_PulsarMsgStream_SearchResult(t 
*testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kSearchResult, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + (*inputStream).Produce(&msgPack) receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -294,10 +278,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kTimeTick, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + (*inputStream).Produce(&msgPack) receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -314,10 +295,7 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kTimeTick, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Broadcast(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + (*inputStream).Broadcast(&msgPack) receiveMsg(outputStream, len(consumerChannels)*len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() @@ -334,164 +312,12 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) { msgPack.Msgs = append(msgPack.Msgs, getTsMsg(internalPb.MsgType_kInsert, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName, repackFunc) - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + (*inputStream).Produce(&msgPack) receiveMsg(outputStream, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() } -func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { - pulsarAddress := "pulsar://localhost:6650" - producerChannels := []string{"insert1", "insert2"} - consumerChannels := []string{"insert1", "insert2"} - consumerSubName := "subInsert" - - baseMsg := BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []int32{1, 3}, - } - - insertRequest := internalPb.InsertRequest{ - MsgType: internalPb.MsgType_kInsert, - ReqId: 1, - CollectionName: "Collection", - PartitionTag: "Partition", - SegmentId: 1, - ChannelId: 1, - ProxyId: 1, - Timestamps: []Timestamp{1, 1}, - RowIds: []int64{1, 3}, - RowData: []*commonPb.Blob{{}, {}}, - } - insertMsg := &InsertMsg{ - BaseMsg: baseMsg, - InsertRequest: insertRequest, - } - var tsMsg TsMsg = insertMsg - msgPack := MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, &tsMsg) - - inputStream := NewPulsarMsgStream(context.Background(), 100) - inputStream.SetPulsarCient(pulsarAddress) - inputStream.CreatePulsarProducers(producerChannels) - inputStream.Start() - - outputStream := NewPulsarMsgStream(context.Background(), 100) - outputStream.SetPulsarCient(pulsarAddress) - unmarshalDispatcher := NewUnmarshalDispatcher() - outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) - outputStream.Start() - var output MsgStream = outputStream - - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } - receiveMsg(&output, len(msgPack.Msgs)*2) - (*inputStream).Close() - (*outputStream).Close() -} - -func 
TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) { - pulsarAddress := "pulsar://localhost:6650" - producerChannels := []string{"insert1", "insert2"} - consumerChannels := []string{"insert1", "insert2"} - consumerSubName := "subInsert" - - baseMsg := BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []int32{1, 3}, - } - - deleteRequest := internalPb.DeleteRequest{ - MsgType: internalPb.MsgType_kDelete, - ReqId: 1, - CollectionName: "Collection", - ChannelId: 1, - ProxyId: 1, - Timestamps: []Timestamp{1, 1}, - PrimaryKeys: []int64{1, 3}, - } - deleteMsg := &DeleteMsg{ - BaseMsg: baseMsg, - DeleteRequest: deleteRequest, - } - var tsMsg TsMsg = deleteMsg - msgPack := MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, &tsMsg) - - inputStream := NewPulsarMsgStream(context.Background(), 100) - inputStream.SetPulsarCient(pulsarAddress) - inputStream.CreatePulsarProducers(producerChannels) - inputStream.Start() - - outputStream := NewPulsarMsgStream(context.Background(), 100) - outputStream.SetPulsarCient(pulsarAddress) - unmarshalDispatcher := NewUnmarshalDispatcher() - outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) - outputStream.Start() - var output MsgStream = outputStream - - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } - receiveMsg(&output, len(msgPack.Msgs)*2) - (*inputStream).Close() - (*outputStream).Close() -} - -func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { - pulsarAddress := "pulsar://localhost:6650" - producerChannels := []string{"insert1", "insert2"} - consumerChannels := []string{"insert1", "insert2"} - consumerSubName := "subInsert" - - baseMsg := BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []int32{1}, - } - - timeTickRequest := internalPb.TimeTickMsg{ - MsgType: internalPb.MsgType_kTimeTick, - PeerId: int64(1), - Timestamp: uint64(1), - } - timeTick := &TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickRequest, - } - var tsMsg TsMsg = timeTick - msgPack := MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, &tsMsg) - - inputStream := NewPulsarMsgStream(context.Background(), 100) - inputStream.SetPulsarCient(pulsarAddress) - inputStream.CreatePulsarProducers(producerChannels) - inputStream.Start() - - outputStream := NewPulsarMsgStream(context.Background(), 100) - outputStream.SetPulsarCient(pulsarAddress) - unmarshalDispatcher := NewUnmarshalDispatcher() - outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) - outputStream.Start() - var output MsgStream = outputStream - - err := (*inputStream).Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } - receiveMsg(&output, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() -} - func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { pulsarAddress := "pulsar://localhost:6650" producerChannels := []string{"insert1", "insert2"} @@ -509,18 +335,9 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5)) inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Broadcast(&msgPack0) - if err != nil { - log.Fatalf("broadcast error = %v", err) - } - err = (*inputStream).Produce(&msgPack1) - if err != nil { - log.Fatalf("produce error = %v", err) - } - err = (*inputStream).Broadcast(&msgPack2) - if err != nil { - log.Fatalf("broadcast error = %v", 
err) - } + (*inputStream).Broadcast(&msgPack0) + (*inputStream).Produce(&msgPack1) + (*inputStream).Broadcast(&msgPack2) receiveMsg(outputStream, len(msgPack1.Msgs)) outputTtStream := (*outputStream).(*PulsarTtMsgStream) fmt.Printf("timestamp = %v", outputTtStream.lastTimeStamp) diff --git a/internal/msgstream/task_test.go b/internal/msgstream/task_test.go index 3c1dc426c3..4755adef8e 100644 --- a/internal/msgstream/task_test.go +++ b/internal/msgstream/task_test.go @@ -2,10 +2,7 @@ package msgstream import ( "context" - "errors" "fmt" - commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "log" "testing" "github.com/golang/protobuf/proto" @@ -40,53 +37,6 @@ func (tt *InsertTask) Unmarshal(input []byte) (*TsMsg, error) { return &tsMsg, nil } -func newRepackFunc(tsMsgs []*TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - if (*request).Type() != internalPb.MsgType_kInsert { - return nil, errors.New(string("msg's must be Insert")) - } - insertRequest := (*request).(*InsertTask).InsertRequest - keys := hashKeys[i] - - timestampLen := len(insertRequest.Timestamps) - rowIDLen := len(insertRequest.RowIds) - rowDataLen := len(insertRequest.RowData) - keysLen := len(keys) - - if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { - return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal")) - } - for index, key := range keys { - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - - sliceRequest := internalPb.InsertRequest{ - MsgType: internalPb.MsgType_kInsert, - ReqId: insertRequest.ReqId, - CollectionName: insertRequest.CollectionName, - PartitionTag: insertRequest.PartitionTag, - SegmentId: insertRequest.SegmentId, - ChannelId: insertRequest.ChannelId, - ProxyId: insertRequest.ProxyId, - Timestamps: []uint64{insertRequest.Timestamps[index]}, - RowIds: []int64{insertRequest.RowIds[index]}, - RowData: []*commonPb.Blob{insertRequest.RowData[index]}, - } - - var msg TsMsg = &InsertTask{ - InsertMsg: InsertMsg{InsertRequest: sliceRequest}, - } - - result[key].Msgs = append(result[key].Msgs, &msg) - } - } - return result, nil -} - func getMsg(reqID UniqueID, hashValue int32) *TsMsg { var tsMsg TsMsg baseMsg := BaseMsg{ @@ -103,8 +53,6 @@ func getMsg(reqID UniqueID, hashValue int32) *TsMsg { ChannelId: 1, ProxyId: 1, Timestamps: []Timestamp{1}, - RowIds: []int64{1}, - RowData: []*commonPb.Blob{{}}, } insertMsg := InsertMsg{ BaseMsg: baseMsg, @@ -131,7 +79,6 @@ func TestStream_task_Insert(t *testing.T) { inputStream := NewPulsarMsgStream(context.Background(), 100) inputStream.SetPulsarCient(pulsarAddress) inputStream.CreatePulsarProducers(producerChannels) - inputStream.SetRepackFunc(newRepackFunc) inputStream.Start() outputStream := NewPulsarMsgStream(context.Background(), 100) @@ -142,10 +89,7 @@ func TestStream_task_Insert(t *testing.T) { outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) outputStream.Start() - err := inputStream.Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + inputStream.Produce(&msgPack) receiveCount := 0 for { result := (*outputStream).Consume() diff --git a/internal/msgstream/unmarshal_test.go b/internal/msgstream/unmarshal_test.go index 24812eb520..0c12faa718 100644 --- a/internal/msgstream/unmarshal_test.go +++ b/internal/msgstream/unmarshal_test.go @@ -3,7 +3,6 @@ package msgstream import ( "context" 
"fmt" - "log" "testing" "github.com/golang/protobuf/proto" @@ -48,10 +47,7 @@ func TestStream_unmarshal_Insert(t *testing.T) { outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) outputStream.Start() - err := inputStream.Produce(&msgPack) - if err != nil { - log.Fatalf("produce error = %v", err) - } + inputStream.Produce(&msgPack) receiveCount := 0 for { result := (*outputStream).Consume() diff --git a/internal/proxy/grpc_service.go b/internal/proxy/grpc_service.go index 0b99b66690..877136a72b 100644 --- a/internal/proxy/grpc_service.go +++ b/internal/proxy/grpc_service.go @@ -3,7 +3,7 @@ package proxy import ( "context" "errors" - "github.com/golang/protobuf/proto" + "github.com/gogo/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/msgstream" "log" @@ -35,18 +35,21 @@ func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb. defer it.cancel() - p.taskSch.DmQueue.Enqueue(it) - select { - case <-ctx.Done(): - log.Print("insert timeout!") - return &servicepb.IntegerRangeResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "insert timeout!", - }, - }, errors.New("insert timeout!") - case result := <-it.resultChan: - return result, nil + var t task = it + p.taskSch.DmQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("insert timeout!") + return &servicepb.IntegerRangeResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "insert timeout!", + }, + }, errors.New("insert timeout!") + case result := <-it.resultChan: + return result, nil + } } } @@ -66,16 +69,19 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc cct.ctx, cct.cancel = context.WithCancel(ctx) defer cct.cancel() - p.taskSch.DdQueue.Enqueue(cct) - select { - case <-ctx.Done(): - log.Print("create collection timeout!") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "create collection timeout!", - }, errors.New("create collection timeout!") - case result := <-cct.resultChan: - return result, nil + var t task = cct + p.taskSch.DdQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("create collection timeout!") + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "create collection timeout!", + }, errors.New("create collection timeout!") + case result := <-cct.resultChan: + return result, nil + } } } @@ -96,18 +102,21 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu qt.SearchRequest.Query.Value = queryBytes defer qt.cancel() - p.taskSch.DqQueue.Enqueue(qt) - select { - case <-ctx.Done(): - log.Print("query timeout!") - return &servicepb.QueryResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "query timeout!", - }, - }, errors.New("query timeout!") - case result := <-qt.resultChan: - return result, nil + var t task = qt + p.taskSch.DqQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("query timeout!") + return &servicepb.QueryResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "query timeout!", + }, + }, errors.New("query timeout!") + case result := <-qt.resultChan: + return result, nil + } } } @@ -125,16 +134,19 @@ func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionNam dct.ctx, dct.cancel = context.WithCancel(ctx) defer dct.cancel() - 
p.taskSch.DdQueue.Enqueue(dct) - select { - case <-ctx.Done(): - log.Print("create collection timeout!") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "create collection timeout!", - }, errors.New("create collection timeout!") - case result := <-dct.resultChan: - return result, nil + var t task = dct + p.taskSch.DdQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("create collection timeout!") + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "create collection timeout!", + }, errors.New("create collection timeout!") + case result := <-dct.resultChan: + return result, nil + } } } @@ -152,19 +164,22 @@ func (p *Proxy) HasCollection(ctx context.Context, req *servicepb.CollectionName hct.ctx, hct.cancel = context.WithCancel(ctx) defer hct.cancel() - p.taskSch.DqQueue.Enqueue(hct) - select { - case <-ctx.Done(): - log.Print("has collection timeout!") - return &servicepb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "has collection timeout!", - }, - Value: false, - }, errors.New("has collection timeout!") - case result := <-hct.resultChan: - return result, nil + var t task = hct + p.taskSch.DqQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("has collection timeout!") + return &servicepb.BoolResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "has collection timeout!", + }, + Value: false, + }, errors.New("has collection timeout!") + case result := <-hct.resultChan: + return result, nil + } } } @@ -182,18 +197,21 @@ func (p *Proxy) DescribeCollection(ctx context.Context, req *servicepb.Collectio dct.ctx, dct.cancel = context.WithCancel(ctx) defer dct.cancel() - p.taskSch.DqQueue.Enqueue(dct) - select { - case <-ctx.Done(): - log.Print("has collection timeout!") - return &servicepb.CollectionDescription{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "describe collection timeout!", - }, - }, errors.New("describe collection timeout!") - case result := <-dct.resultChan: - return result, nil + var t task = dct + p.taskSch.DqQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("has collection timeout!") + return &servicepb.CollectionDescription{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "describe collection timeout!", + }, + }, errors.New("describe collection timeout!") + case result := <-dct.resultChan: + return result, nil + } } } @@ -210,18 +228,21 @@ func (p *Proxy) ShowCollections(ctx context.Context, req *commonpb.Empty) (*serv sct.ctx, sct.cancel = context.WithCancel(ctx) defer sct.cancel() - p.taskSch.DqQueue.Enqueue(sct) - select { - case <-ctx.Done(): - log.Print("show collections timeout!") - return &servicepb.StringListResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: "show collections timeout!", - }, - }, errors.New("show collections timeout!") - case result := <-sct.resultChan: - return result, nil + var t task = sct + p.taskSch.DqQueue.Enqueue(&t) + for { + select { + case <-ctx.Done(): + log.Print("show collections timeout!") + return &servicepb.StringListResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "show collections timeout!", + }, + }, errors.New("show collections timeout!") + case result := <-sct.resultChan: + return result, nil + } } } diff --git 
a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 9acc93a330..7ede1b8d2b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "google.golang.org/grpc" "log" "math/rand" "net" @@ -15,6 +14,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" + "google.golang.org/grpc" ) type UniqueID = typeutil.UniqueID @@ -157,7 +157,7 @@ func (p *Proxy) queryResultLoop() { if len(queryResultBuf[reqId]) == 4 { // TODO: use the number of query node instead t := p.taskSch.getTaskByReqId(reqId) - qt := t.(*QueryTask) + qt := (*t).(*QueryTask) qt.resultBuf <- queryResultBuf[reqId] delete(queryResultBuf, reqId) } diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index dcf1335adb..361733468c 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -11,7 +11,7 @@ import ( type BaseTaskQueue struct { unissuedTasks *list.List - activeTasks map[Timestamp]task + activeTasks map[Timestamp]*task utLock sync.Mutex atLock sync.Mutex } @@ -24,23 +24,23 @@ func (queue *BaseTaskQueue) Empty() bool { return queue.unissuedTasks.Len() <= 0 && len(queue.activeTasks) <= 0 } -func (queue *BaseTaskQueue) AddUnissuedTask(t task) { +func (queue *BaseTaskQueue) AddUnissuedTask(t *task) { queue.utLock.Lock() defer queue.utLock.Unlock() queue.unissuedTasks.PushBack(t) } -func (queue *BaseTaskQueue) FrontUnissuedTask() task { +func (queue *BaseTaskQueue) FrontUnissuedTask() *task { queue.utLock.Lock() defer queue.utLock.Unlock() if queue.unissuedTasks.Len() <= 0 { log.Fatal("sorry, but the unissued task list is empty!") return nil } - return queue.unissuedTasks.Front().Value.(task) + return queue.unissuedTasks.Front().Value.(*task) } -func (queue *BaseTaskQueue) PopUnissuedTask() task { +func (queue *BaseTaskQueue) PopUnissuedTask() *task { queue.utLock.Lock() defer queue.utLock.Unlock() if queue.unissuedTasks.Len() <= 0 { @@ -48,13 +48,13 @@ func (queue *BaseTaskQueue) PopUnissuedTask() task { return nil } ft := queue.unissuedTasks.Front() - return queue.unissuedTasks.Remove(ft).(task) + return queue.unissuedTasks.Remove(ft).(*task) } -func (queue *BaseTaskQueue) AddActiveTask(t task) { +func (queue *BaseTaskQueue) AddActiveTask(t *task) { queue.atLock.Lock() defer queue.atLock.Lock() - ts := t.EndTs() + ts := (*t).EndTs() _, ok := queue.activeTasks[ts] if ok { log.Fatalf("task with timestamp %v already in active task list!", ts) @@ -62,7 +62,7 @@ func (queue *BaseTaskQueue) AddActiveTask(t task) { queue.activeTasks[ts] = t } -func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) task { +func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) *task { queue.atLock.Lock() defer queue.atLock.Lock() t, ok := queue.activeTasks[ts] @@ -74,19 +74,19 @@ func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) task { return nil } -func (queue *BaseTaskQueue) getTaskByReqId(reqId UniqueID) task { +func (queue *BaseTaskQueue) getTaskByReqId(reqId UniqueID) *task { queue.utLock.Lock() defer queue.utLock.Lock() for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() { - if e.Value.(task).Id() == reqId { - return e.Value.(task) + if (*(e.Value.(*task))).Id() == reqId { + return e.Value.(*task) } } queue.atLock.Lock() defer queue.atLock.Unlock() for ats := range queue.activeTasks { - if queue.activeTasks[ats].Id() == reqId { + if 
(*(queue.activeTasks[ats])).Id() == reqId { return queue.activeTasks[ats] } } @@ -98,7 +98,7 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool { queue.utLock.Lock() defer queue.utLock.Unlock() for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() { - if e.Value.(task).EndTs() >= ts { + if (*(e.Value.(*task))).EndTs() >= ts { return false } } @@ -114,20 +114,20 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool { return true } -type DdTaskQueue struct { +type ddTaskQueue struct { BaseTaskQueue lock sync.Mutex } -type DmTaskQueue struct { +type dmTaskQueue struct { BaseTaskQueue } -type DqTaskQueue struct { +type dqTaskQueue struct { BaseTaskQueue } -func (queue *DdTaskQueue) Enqueue(t task) error { +func (queue *ddTaskQueue) Enqueue(t *task) error { queue.lock.Lock() defer queue.lock.Unlock() // TODO: set Ts, ReqId, ProxyId @@ -135,49 +135,22 @@ func (queue *DdTaskQueue) Enqueue(t task) error { return nil } -func (queue *DmTaskQueue) Enqueue(t task) error { +func (queue *dmTaskQueue) Enqueue(t *task) error { // TODO: set Ts, ReqId, ProxyId queue.AddUnissuedTask(t) return nil } -func (queue *DqTaskQueue) Enqueue(t task) error { +func (queue *dqTaskQueue) Enqueue(t *task) error { // TODO: set Ts, ReqId, ProxyId queue.AddUnissuedTask(t) return nil } -func NewDdTaskQueue() *DdTaskQueue { - return &DdTaskQueue{ - BaseTaskQueue: BaseTaskQueue{ - unissuedTasks: list.New(), - activeTasks: make(map[Timestamp]task), - }, - } -} - -func NewDmTaskQueue() *DmTaskQueue { - return &DmTaskQueue{ - BaseTaskQueue: BaseTaskQueue{ - unissuedTasks: list.New(), - activeTasks: make(map[Timestamp]task), - }, - } -} - -func NewDqTaskQueue() *DqTaskQueue { - return &DqTaskQueue{ - BaseTaskQueue: BaseTaskQueue{ - unissuedTasks: list.New(), - activeTasks: make(map[Timestamp]task), - }, - } -} - type TaskScheduler struct { - DdQueue *DdTaskQueue - DmQueue *DmTaskQueue - DqQueue *DqTaskQueue + DdQueue *ddTaskQueue + DmQueue *dmTaskQueue + DqQueue *dqTaskQueue idAllocator *allocator.IdAllocator tsoAllocator *allocator.TimestampAllocator @@ -192,9 +165,6 @@ func NewTaskScheduler(ctx context.Context, tsoAllocator *allocator.TimestampAllocator) (*TaskScheduler, error) { ctx1, cancel := context.WithCancel(ctx) s := &TaskScheduler{ - DdQueue: NewDdTaskQueue(), - DmQueue: NewDmTaskQueue(), - DqQueue: NewDqTaskQueue(), idAllocator: idAllocator, tsoAllocator: tsoAllocator, ctx: ctx1, @@ -204,19 +174,19 @@ func NewTaskScheduler(ctx context.Context, return s, nil } -func (sched *TaskScheduler) scheduleDdTask() task { +func (sched *TaskScheduler) scheduleDdTask() *task { return sched.DdQueue.PopUnissuedTask() } -func (sched *TaskScheduler) scheduleDmTask() task { +func (sched *TaskScheduler) scheduleDmTask() *task { return sched.DmQueue.PopUnissuedTask() } -func (sched *TaskScheduler) scheduleDqTask() task { +func (sched *TaskScheduler) scheduleDqTask() *task { return sched.DqQueue.PopUnissuedTask() } -func (sched *TaskScheduler) getTaskByReqId(reqId UniqueID) task { +func (sched *TaskScheduler) getTaskByReqId(reqId UniqueID) *task { if t := sched.DdQueue.getTaskByReqId(reqId); t != nil { return t } @@ -241,22 +211,22 @@ func (sched *TaskScheduler) definitionLoop() { //sched.DdQueue.atLock.Lock() t := sched.scheduleDdTask() - err := t.PreExecute() + err := (*t).PreExecute() if err != nil { return } - err = t.Execute() + err = (*t).Execute() if err != nil { log.Printf("execute definition task failed, error = %v", err) } - t.Notify(err) + (*t).Notify(err) sched.DdQueue.AddActiveTask(t) - 
t.WaitToFinish() - t.PostExecute() + (*t).WaitToFinish() + (*t).PostExecute() - sched.DdQueue.PopActiveTask(t.EndTs()) + sched.DdQueue.PopActiveTask((*t).EndTs()) } } @@ -272,27 +242,27 @@ func (sched *TaskScheduler) manipulationLoop() { sched.DmQueue.atLock.Lock() t := sched.scheduleDmTask() - if err := t.PreExecute(); err != nil { + if err := (*t).PreExecute(); err != nil { return } go func() { - err := t.Execute() + err := (*t).Execute() if err != nil { log.Printf("execute manipulation task failed, error = %v", err) } - t.Notify(err) + (*t).Notify(err) }() sched.DmQueue.AddActiveTask(t) sched.DmQueue.atLock.Unlock() go func() { - t.WaitToFinish() - t.PostExecute() + (*t).WaitToFinish() + (*t).PostExecute() // remove from active list - sched.DmQueue.PopActiveTask(t.EndTs()) + sched.DmQueue.PopActiveTask((*t).EndTs()) }() } } @@ -309,27 +279,27 @@ func (sched *TaskScheduler) queryLoop() { sched.DqQueue.atLock.Lock() t := sched.scheduleDqTask() - if err := t.PreExecute(); err != nil { + if err := (*t).PreExecute(); err != nil { return } go func() { - err := t.Execute() + err := (*t).Execute() if err != nil { log.Printf("execute query task failed, error = %v", err) } - t.Notify(err) + (*t).Notify(err) }() sched.DqQueue.AddActiveTask(t) sched.DqQueue.atLock.Unlock() go func() { - t.WaitToFinish() - t.PostExecute() + (*t).WaitToFinish() + (*t).PostExecute() // remove from active list - sched.DqQueue.PopActiveTask(t.EndTs()) + sched.DqQueue.PopActiveTask((*t).EndTs()) }() } } diff --git a/internal/proxy/timetick.go b/internal/proxy/timetick.go index 3778c25053..6269940b28 100644 --- a/internal/proxy/timetick.go +++ b/internal/proxy/timetick.go @@ -51,6 +51,7 @@ func newTimeTick(ctx context.Context, tsoAllocator *allocator.TimestampAllocator return t } + func (tt *timeTick) tick() error { if tt.lastTick == tt.currentTick { diff --git a/internal/proxy/timetick_test.go b/internal/proxy/timetick_test.go index e159188c70..edaa4bd5b1 100644 --- a/internal/proxy/timetick_test.go +++ b/internal/proxy/timetick_test.go @@ -33,7 +33,7 @@ func TestTimeTick(t *testing.T) { tt := timeTick{ interval: 200, pulsarProducer: producer, - peerID: 1, + peerID: 1, ctx: ctx, areRequestsDelivered: func(ts Timestamp) bool { return true }, } diff --git a/internal/reader/col_seg_container.go b/internal/reader/col_seg_container.go index c71150bd57..bfb0f05b06 100644 --- a/internal/reader/col_seg_container.go +++ b/internal/reader/col_seg_container.go @@ -77,9 +77,9 @@ func (container *ColSegContainer) getCollectionByName(collectionName string) (*C } //----------------------------------------------------------------------------------------------------- partition -func (container *ColSegContainer) addPartition(collection *Collection, partitionTag string) (*Partition, error) { +func (container *ColSegContainer) addPartition(collection *Collection, partitionTag string) error { if collection == nil { - return nil, errors.New("null collection") + return errors.New("null collection") } var newPartition = newPartition(partitionTag) @@ -87,11 +87,11 @@ func (container *ColSegContainer) addPartition(collection *Collection, partition for _, col := range container.collections { if col.Name() == collection.Name() { *col.Partitions() = append(*col.Partitions(), newPartition) - return newPartition, nil + return nil } } - return nil, errors.New("cannot find collection, name = " + collection.Name()) + return errors.New("cannot find collection, name = " + collection.Name()) } func (container *ColSegContainer) removePartition(partition 
*Partition) error { @@ -138,13 +138,13 @@ func (container *ColSegContainer) getPartitionByTag(partitionTag string) (*Parti } //----------------------------------------------------------------------------------------------------- segment -func (container *ColSegContainer) addSegment(collection *Collection, partition *Partition, segmentID int64) (*Segment, error) { +func (container *ColSegContainer) addSegment(collection *Collection, partition *Partition, segmentID int64) error { if collection == nil { - return nil, errors.New("null collection") + return errors.New("null collection") } if partition == nil { - return nil, errors.New("null partition") + return errors.New("null partition") } var newSegment = newSegment(collection, segmentID) @@ -155,13 +155,13 @@ func (container *ColSegContainer) addSegment(collection *Collection, partition * for _, p := range *col.Partitions() { if p.Tag() == partition.Tag() { *p.Segments() = append(*p.Segments(), newSegment) - return newSegment, nil + return nil } } } } - return nil, errors.New("cannot find collection or segment") + return errors.New("cannot find collection or segment") } func (container *ColSegContainer) removeSegment(segment *Segment) error { diff --git a/internal/reader/col_seg_container_test.go b/internal/reader/col_seg_container_test.go deleted file mode 100644 index 014d43fc7d..0000000000 --- a/internal/reader/col_seg_container_test.go +++ /dev/null @@ -1,675 +0,0 @@ -package reader - -import ( - "context" - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" - "testing" -) - -//----------------------------------------------------------------------------------------------------- collection -func TestColSegContainer_addCollection(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) -} - -func TestColSegContainer_removeCollection(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: 
schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - err := node.container.removeCollection(collection) - assert.NoError(t, err) - assert.Equal(t, len(node.container.collections), 0) -} - -func TestColSegContainer_getCollectionByID(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - targetCollection, err := node.container.getCollectionByID(UniqueID(0)) - assert.NoError(t, err) - assert.NotNil(t, targetCollection) - assert.Equal(t, targetCollection.meta.Schema.Name, "collection0") - assert.Equal(t, targetCollection.meta.Id, UniqueID(0)) -} - -func TestColSegContainer_getCollectionByName(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, 
collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - targetCollection, err := node.container.getCollectionByName("collection0") - assert.NoError(t, err) - assert.NotNil(t, targetCollection) - assert.Equal(t, targetCollection.meta.Schema.Name, "collection0") - assert.Equal(t, targetCollection.meta.Id, UniqueID(0)) -} - -//----------------------------------------------------------------------------------------------------- partition -func TestColSegContainer_addPartition(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - for _, tag := range collectionMeta.PartitionTags { - targetPartition, err := node.container.addPartition(collection, tag) - assert.NoError(t, err) - assert.Equal(t, targetPartition.partitionTag, "default") - } -} - -func TestColSegContainer_removePartition(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - for _, tag := range collectionMeta.PartitionTags { - targetPartition, err := node.container.addPartition(collection, tag) - assert.NoError(t, err) - assert.Equal(t, targetPartition.partitionTag, "default") - err = node.container.removePartition(targetPartition) - assert.NoError(t, err) - } -} - -func TestColSegContainer_getPartitionByTag(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node 
:= NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - for _, tag := range collectionMeta.PartitionTags { - targetPartition, err := node.container.addPartition(collection, tag) - assert.NoError(t, err) - assert.Equal(t, targetPartition.partitionTag, "default") - partition, err := node.container.getPartitionByTag(tag) - assert.NoError(t, err) - assert.NotNil(t, partition) - assert.Equal(t, partition.partitionTag, "default") - } -} - -//----------------------------------------------------------------------------------------------------- segment -func TestColSegContainer_addSegment(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) - assert.NoError(t, err) - - const segmentNum = 3 - for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - } -} - -func TestColSegContainer_removeSegment(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: 
schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) - assert.NoError(t, err) - - const segmentNum = 3 - for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - err = node.container.removeSegment(targetSeg) - assert.NoError(t, err) - } -} - -func TestColSegContainer_getSegmentByID(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) - assert.NoError(t, err) - - const segmentNum = 3 - for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - seg, err := node.container.getSegmentByID(UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, seg.segmentID, UniqueID(i)) - } -} - -func TestColSegContainer_hasSegment(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - 
}, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) - assert.NoError(t, err) - - const segmentNum = 3 - for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - hasSeg := node.container.hasSegment(UniqueID(i)) - assert.Equal(t, hasSeg, true) - hasSeg = node.container.hasSegment(UniqueID(i + 100)) - assert.Equal(t, hasSeg, false) - } -} diff --git a/internal/reader/collection_test.go b/internal/reader/collection_test.go index e05964b867..24b7cbe2cb 100644 --- a/internal/reader/collection_test.go +++ b/internal/reader/collection_test.go @@ -1,165 +1,33 @@ package reader -import ( - "context" - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" - "testing" -) - -func TestCollection_Partitions(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - for _, tag := range collectionMeta.PartitionTags { - _, err := node.container.addPartition(collection, tag) - assert.NoError(t, err) - } - - partitions := collection.Partitions() - assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions)) -} - -func TestCollection_newCollection(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: 
"collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) -} - -func TestCollection_deleteCollection(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - deleteCollection(collection) -} +//func TestCollection_NewPartition(t *testing.T) { +// ctx := context.Background() +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, collection.CollectionID, int64(0)) +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, len(collection.Partitions), 1) +//} +// +//func TestCollection_DeletePartition(t *testing.T) { +// ctx := context.Background() +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, collection.CollectionID, int64(0)) +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, len(collection.Partitions), 1) +// +// collection.deletePartition(node, partition) +// +// assert.Equal(t, len(collection.Partitions), 0) +//} diff --git a/internal/reader/meta_service.go b/internal/reader/meta_service.go index a9f99bac8f..9a70ffc502 100644 --- a/internal/reader/meta_service.go +++ b/internal/reader/meta_service.go @@ -3,7 +3,7 @@ package reader import ( "context" "fmt" - "github.com/golang/protobuf/proto" + "github.com/gogo/protobuf/proto" "log" "path" "reflect" @@ -144,7 +144,7 @@ func (mService *metaService) processCollectionCreate(id string, value string) { if col != nil { newCollection := mService.container.addCollection(col, value) for _, partitionTag := range col.PartitionTags { - _, err := mService.container.addPartition(newCollection, partitionTag) + err := mService.container.addPartition(newCollection, partitionTag) if err != nil { log.Println(err) } @@ -174,7 
+174,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) { return } if partition != nil { - _, err = mService.container.addSegment(col, partition, seg.SegmentId) + err = mService.container.addSegment(col, partition, seg.SegmentId) if err != nil { log.Println(err) return diff --git a/internal/reader/partition_test.go b/internal/reader/partition_test.go index 4e68962e6c..f9268cdecf 100644 --- a/internal/reader/partition_test.go +++ b/internal/reader/partition_test.go @@ -1,88 +1,57 @@ package reader -import ( - "context" - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" - "testing" -) - -func TestPartition_Segments(t *testing.T) { - ctx := context.Background() - pulsarUrl := "pulsar://localhost:6650" - node := NewQueryNode(ctx, 0, pulsarUrl) - - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - for _, tag := range collectionMeta.PartitionTags { - _, err := node.container.addPartition(collection, tag) - assert.NoError(t, err) - } - - partitions := collection.Partitions() - assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions)) - - targetPartition := (*partitions)[0] - - const segmentNum = 3 - for i:= 0; i < segmentNum; i++ { - _, err := node.container.addSegment(collection, targetPartition, UniqueID(i)) - assert.NoError(t, err) - } - - segments := targetPartition.Segments() - assert.Equal(t, segmentNum, len(*segments)) -} - -func TestPartition_newPartition(t *testing.T) { - partitionTag := "default" - partition := newPartition(partitionTag) - assert.Equal(t, partition.partitionTag, partitionTag) -} +//func TestPartition_NewSegment(t *testing.T) { +// ctx := context.Background() +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// +// var segment = partition.newSegment(0) +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, collection.CollectionID, int64(0)) +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, node.Collections[0].Partitions[0].Segments[0].SegmentID, int64(0)) +// +// assert.Equal(t, len(collection.Partitions), 1) +// assert.Equal(t, len(node.Collections), 1) +// assert.Equal(t, 
len(node.Collections[0].Partitions[0].Segments), 1) +// +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, node.foundSegmentBySegmentID(int64(0)), true) +//} +// +//func TestPartition_DeleteSegment(t *testing.T) { +// // 1. Construct node, collection, partition and segment +// ctx := context.Background() +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// +// var segment = partition.newSegment(0) +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, collection.CollectionID, int64(0)) +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, node.Collections[0].Partitions[0].Segments[0].SegmentID, int64(0)) +// +// assert.Equal(t, len(collection.Partitions), 1) +// assert.Equal(t, len(node.Collections), 1) +// assert.Equal(t, len(node.Collections[0].Partitions[0].Segments), 1) +// +// assert.Equal(t, segment.SegmentID, int64(0)) +// +// // 2. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// +// assert.Equal(t, len(collection.Partitions), 1) +// assert.Equal(t, len(node.Collections), 1) +// assert.Equal(t, len(node.Collections[0].Partitions[0].Segments), 0) +// assert.Equal(t, node.foundSegmentBySegmentID(int64(0)), false) +//} diff --git a/internal/reader/segment.go b/internal/reader/segment.go index 88965482ee..7c67fac8fb 100644 --- a/internal/reader/segment.go +++ b/internal/reader/segment.go @@ -84,7 +84,7 @@ func (s *Segment) getMemSize() int64 { return int64(memoryUsageInBytes) } -//-------------------------------------------------------------------------------------- preDm functions +//-------------------------------------------------------------------------------------- preprocess functions func (s *Segment) segmentPreInsert(numOfRecords int) int64 { /* long int diff --git a/internal/reader/segment_test.go b/internal/reader/segment_test.go index cb357bad07..a0baad00d3 100644 --- a/internal/reader/segment_test.go +++ b/internal/reader/segment_test.go @@ -1,541 +1,147 @@ package reader -import ( - "encoding/binary" - "github.com/golang/protobuf/proto" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" - "math" - "testing" - - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - - "github.com/stretchr/testify/assert" -) - -//-------------------------------------------------------------------------------------- constructor and destructor -func TestSegment_newSegment(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := 
newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) -} - -func TestSegment_deleteSegment(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - deleteSegment(segment) -} - -//-------------------------------------------------------------------------------------- stats functions -func TestSegment_getRowCount(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) 
- var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offset = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) - - err := segment.segmentInsert(offset, &ids, ×tamps, &records) - assert.NoError(t, err) - - rowCount := segment.getRowCount() - assert.Equal(t, int64(N), rowCount) -} - -func TestSegment_getDeletedCount(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) 
- var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offsetInsert = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offsetInsert, int64(0)) - - var err = segment.segmentInsert(offsetInsert, &ids, ×tamps, &records) - assert.NoError(t, err) - - var offsetDelete = segment.segmentPreDelete(10) - assert.GreaterOrEqual(t, offsetDelete, int64(0)) - - err = segment.segmentDelete(offsetDelete, &ids, ×tamps) - assert.NoError(t, err) - - var deletedCount = segment.getDeletedCount() - // TODO: assert.Equal(t, deletedCount, len(ids)) - assert.Equal(t, deletedCount, int64(0)) -} - -func TestSegment_getMemSize(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) 
- var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offset = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) - - err := segment.segmentInsert(offset, &ids, ×tamps, &records) - assert.NoError(t, err) - - var memSize = segment.getMemSize() - assert.Equal(t, memSize, int64(2785280)) -} - -//-------------------------------------------------------------------------------------- dm & search functions -func TestSegment_segmentInsert(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) 
- var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offset = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) - - err := segment.segmentInsert(offset, &ids, ×tamps, &records) - assert.NoError(t, err) -} - -func TestSegment_segmentDelete(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) - var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offsetInsert = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offsetInsert, int64(0)) - - var err = segment.segmentInsert(offsetInsert, &ids, ×tamps, &records) - assert.NoError(t, err) - - var offsetDelete = segment.segmentPreDelete(10) - assert.GreaterOrEqual(t, offsetDelete, int64(0)) - - err = segment.segmentDelete(offsetDelete, &ids, ×tamps) - assert.NoError(t, err) -} - -//func TestSegment_segmentSearch(t *testing.T) { +//import ( +// "context" +// "encoding/binary" +// "fmt" +// "math" +// "testing" +// +// "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" +// +// "github.com/stretchr/testify/assert" +// msgPb "github.com/zilliztech/milvus-distributed/internal/proto/message" +//) +// +//func TestSegment_ConstructorAndDestructor(t *testing.T) { +// // 1. Construct node, collection, partition and segment +// ctx := context.Background() +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. 
Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_SegmentInsert(t *testing.T) { +// // 1. Construct node, collection, partition and segment +// ctx := context.Background() +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Create ids and timestamps +// ids := []int64{1, 2, 3} +// timestamps := []uint64{0, 0, 0} +// +// // 3. Create records, use schema below: +// // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); +// // schema_tmp->AddField("age", DataType::INT32); +// const DIM = 16 +// const N = 3 +// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} +// var rawData []byte +// for _, ele := range vec { +// buf := make([]byte, 4) +// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) +// rawData = append(rawData, buf...) +// } +// bs := make([]byte, 4) +// binary.LittleEndian.PutUint32(bs, 1) +// rawData = append(rawData, bs...) +// var records []*commonpb.Blob +// for i := 0; i < N; i++ { +// blob := &commonpb.Blob{ +// Value: rawData, +// } +// records = append(records, blob) +// } +// +// // 4. Do PreInsert +// var offset = segment.segmentPreInsert(N) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 5. Do Insert +// var err = segment.segmentInsert(offset, &ids, ×tamps, &records) +// assert.NoError(t, err) +// +// // 6. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_SegmentDelete(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Create ids and timestamps +// ids := []int64{1, 2, 3} +// timestamps := []uint64{0, 0, 0} +// +// // 3. Do PreDelete +// var offset = segment.segmentPreDelete(10) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 4. Do Delete +// var err = segment.segmentDelete(offset, &ids, ×tamps) +// assert.NoError(t, err) +// +// // 5. 
Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_SegmentSearch(t *testing.T) { // ctx := context.Background() // // 1. Construct node, collection, partition and segment // pulsarUrl := "pulsar://localhost:6650" @@ -614,159 +220,307 @@ func TestSegment_segmentDelete(t *testing.T) { // // node.Close() //} +// +//func TestSegment_SegmentPreInsert(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Do PreInsert +// var offset = segment.segmentPreInsert(10) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 3. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_SegmentPreDelete(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Do PreDelete +// var offset = segment.segmentPreDelete(10) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 3. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_GetRowCount(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Create ids and timestamps +// ids := []int64{1, 2, 3} +// timestamps := []uint64{0, 0, 0} +// +// // 3. 
Create records, use schema below: +// // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); +// // schema_tmp->AddField("age", DataType::INT32); +// const DIM = 16 +// const N = 3 +// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} +// var rawData []byte +// for _, ele := range vec { +// buf := make([]byte, 4) +// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) +// rawData = append(rawData, buf...) +// } +// bs := make([]byte, 4) +// binary.LittleEndian.PutUint32(bs, 1) +// rawData = append(rawData, bs...) +// var records []*commonpb.Blob +// for i := 0; i < N; i++ { +// blob := &commonpb.Blob{ +// Value: rawData, +// } +// records = append(records, blob) +// } +// +// // 4. Do PreInsert +// var offset = segment.segmentPreInsert(N) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 5. Do Insert +// var err = segment.segmentInsert(offset, &ids, ×tamps, &records) +// assert.NoError(t, err) +// +// // 6. Get segment row count +// var rowCount = segment.getRowCount() +// assert.Equal(t, rowCount, int64(len(ids))) +// +// // 7. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_GetDeletedCount(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Create ids and timestamps +// ids := []int64{1, 2, 3} +// timestamps := []uint64{0, 0, 0} +// +// // 3. Do PreDelete +// var offset = segment.segmentPreDelete(10) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 4. Do Delete +// var err = segment.segmentDelete(offset, &ids, ×tamps) +// assert.NoError(t, err) +// +// // 5. Get segment deleted count +// var deletedCount = segment.getDeletedCount() +// // TODO: assert.Equal(t, deletedCount, len(ids)) +// assert.Equal(t, deletedCount, int64(0)) +// +// // 6. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} +// +//func TestSegment_GetMemSize(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", "") +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. 
Create ids and timestamps +// ids := []int64{1, 2, 3} +// timestamps := []uint64{0, 0, 0} +// +// // 3. Create records, use schema below: +// // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); +// // schema_tmp->AddField("age", DataType::INT32); +// const DIM = 16 +// const N = 3 +// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} +// var rawData []byte +// for _, ele := range vec { +// buf := make([]byte, 4) +// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) +// rawData = append(rawData, buf...) +// } +// bs := make([]byte, 4) +// binary.LittleEndian.PutUint32(bs, 1) +// rawData = append(rawData, bs...) +// var records []*commonpb.Blob +// for i := 0; i < N; i++ { +// blob := &commonpb.Blob{ +// Value: rawData, +// } +// records = append(records, blob) +// } +// +// // 4. Do PreInsert +// var offset = segment.segmentPreInsert(N) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 5. Do Insert +// var err = segment.segmentInsert(offset, &ids, ×tamps, &records) +// assert.NoError(t, err) +// +// // 6. Get memory usage in bytes +// var memSize = segment.getMemSize() +// assert.Equal(t, memSize, int64(2785280)) +// +// // 7. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} -//-------------------------------------------------------------------------------------- preDm functions -func TestSegment_segmentPreInsert(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) 
- var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offset = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) -} - -func TestSegment_segmentPreDelete(t *testing.T) { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - DataType: schemapb.DataType_INT32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "1", - }, - }, - } - - schema := schemapb.CollectionSchema{ - Name: "collection0", - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - Id: UniqueID(0), - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIds: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - - collectionMetaBlob := proto.MarshalTextString(&collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - collection := newCollection(&collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.Id, UniqueID(0)) - - segmentID := UniqueID(0) - segment := newSegment(collection, segmentID) - assert.Equal(t, segmentID, segment.segmentID) - - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} - - const DIM = 16 - const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) - var records []*commonpb.Blob - for i := 0; i < N; i++ { - blob := &commonpb.Blob{ - Value: rawData, - } - records = append(records, blob) - } - - var offsetInsert = segment.segmentPreInsert(N) - assert.GreaterOrEqual(t, offsetInsert, int64(0)) - - var err = segment.segmentInsert(offsetInsert, &ids, ×tamps, &records) - assert.NoError(t, err) - - var offsetDelete = segment.segmentPreDelete(10) - assert.GreaterOrEqual(t, offsetDelete, int64(0)) -} +//func TestSegment_RealSchemaTest(t *testing.T) { +// ctx := context.Background() +// // 1. Construct node, collection, partition and segment +// var schemaString = "id: 6875229265736357360\nname: \"collection0\"\nschema: \u003c\n " + +// "field_metas: \u003c\n field_name: \"field_3\"\n type: INT32\n dim: 1\n \u003e\n " + +// "field_metas: \u003c\n field_name: \"field_vec\"\n type: VECTOR_FLOAT\n dim: 16\n " + +// "\u003e\n\u003e\ncreate_time: 1600764055\nsegment_ids: 6875229265736357360\npartition_tags: \"default\"\n" +// pulsarUrl := "pulsar://localhost:6650" +// node := NewQueryNode(ctx, 0, pulsarUrl) +// var collection = node.newCollection(0, "collection0", schemaString) +// var partition = collection.newPartition("partition0") +// var segment = partition.newSegment(0) +// +// node.SegmentsMap[int64(0)] = segment +// +// assert.Equal(t, collection.CollectionName, "collection0") +// assert.Equal(t, partition.partitionTag, "partition0") +// assert.Equal(t, segment.SegmentID, int64(0)) +// assert.Equal(t, len(node.SegmentsMap), 1) +// +// // 2. Create ids and timestamps +// ids := []int64{1, 2, 3} +// timestamps := []uint64{0, 0, 0} +// +// // 3. 
Create records, use schema below: +// // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); +// // schema_tmp->AddField("age", DataType::INT32); +// const DIM = 16 +// const N = 3 +// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} +// var rawData []byte +// for _, ele := range vec { +// buf := make([]byte, 4) +// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) +// rawData = append(rawData, buf...) +// } +// bs := make([]byte, 4) +// binary.LittleEndian.PutUint32(bs, 1) +// rawData = append(rawData, bs...) +// var records []*commonpb.Blob +// for i := 0; i < N; i++ { +// blob := &commonpb.Blob { +// Value: rawData, +// } +// records = append(records, blob) +// } +// +// // 4. Do PreInsert +// var offset = segment.segmentPreInsert(N) +// assert.GreaterOrEqual(t, offset, int64(0)) +// +// // 5. Do Insert +// var err = segment.segmentInsert(offset, &ids, ×tamps, &records) +// assert.NoError(t, err) +// +// // 6. Destruct collection, partition and segment +// partition.deleteSegment(node, segment) +// collection.deletePartition(node, partition) +// node.deleteCollection(collection) +// +// assert.Equal(t, len(node.Collections), 0) +// assert.Equal(t, len(node.SegmentsMap), 0) +// +// node.Close() +//} diff --git a/internal/util/tsoutil/tso.go b/internal/util/tsoutil/tso.go index c1e3b3491b..625c16635d 100644 --- a/internal/util/tsoutil/tso.go +++ b/internal/util/tsoutil/tso.go @@ -1,14 +1,7 @@ package tsoutil import ( - "fmt" - "path" - "strconv" "time" - - "github.com/zilliztech/milvus-distributed/internal/conf" - "github.com/zilliztech/milvus-distributed/internal/kv" - "go.etcd.io/etcd/clientv3" ) const ( @@ -27,15 +20,3 @@ func ParseTS(ts uint64) (time.Time, uint64) { physicalTime := time.Unix(int64(physical/1000), int64(physical)%1000*time.Millisecond.Nanoseconds()) return physicalTime, logical } - -func NewTSOKVBase(subPath string) *kv.EtcdKV { - etcdAddr := conf.Config.Etcd.Address - etcdAddr += ":" - etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10) - fmt.Println("etcdAddr ::: ", etcdAddr) - client, _ := clientv3.New(clientv3.Config{ - Endpoints: []string{etcdAddr}, - DialTimeout: 5 * time.Second, - }) - return kv.NewEtcdKV(client, path.Join(conf.Config.Etcd.Rootpath, subPath)) -} diff --git a/scripts/README.md b/scripts/README.md index 5407d6bb3d..cdd5655119 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -37,7 +37,7 @@ cd milvus-distributed pwd_dir=`pwd` export PATH=$PATH:$(go env GOPATH)/bin - export protoc=${pwd_dir}/internal/core/cmake_build/thirdparty/protobuf/protobuf-build/protoc + export protoc=${pwd_dir}/cmake_build/thirdparty/protobuf/protobuf-build/protoc ./ci/scripts/proto_gen_go.sh ```
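
For context on the retained half of the tsoutil change above: `ParseTS(ts uint64) (time.Time, uint64)` stays in `internal/util/tsoutil/tso.go` while the etcd-backed `NewTSOKVBase` helper is removed. The following is a minimal usage sketch, not part of this patch; the `main` package and the placeholder timestamp value are illustrative assumptions, and because `tsoutil` is an internal package the call site must live inside the milvus-distributed module.

```go
package main

import (
	"fmt"

	"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
)

func main() {
	// Placeholder hybrid timestamp; real values are issued by the master's
	// timestamp oracle, not hard-coded like this.
	var ts uint64 = 0

	// ParseTS splits a hybrid timestamp into its physical wall-clock part
	// and its logical counter, per the signature kept in tso.go.
	physicalTime, logical := tsoutil.ParseTS(ts)
	fmt.Println("physical:", physicalTime, "logical:", logical)
}
```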