mirror of https://github.com/milvus-io/milvus.git

Parallelize the processing of loading binlog and saving index files

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/4973/head^2
parent 30cc84b164
commit 41e1975611
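
This commit replaces serial loops in the IndexNode build task and the ProxyNode search path with a new shared helper, funcutil.ProcessFuncParallel, which splits the index range [0, total) into batches, runs each batch in its own goroutine, and returns the first error reported by any batch. A minimal caller sketch (the work function and total count are hypothetical; the helper's signature is the one added in this diff):

    work := func(idx int) error {
        // handle element idx; an error stops this batch and is returned by the helper
        return nil
    }
    // one batch per CPU core, as the IndexNode and ProxyNode call sites below do
    if err := funcutil.ProcessFuncParallel(total, runtime.NumCPU(), work, "work"); err != nil {
        return err
    }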

@@ -3,6 +3,7 @@ package indexnode
import (
"context"
"errors"
"runtime"
"strconv"

"go.uber.org/zap"

@@ -228,15 +229,23 @@ func (it *IndexBuildTask) Execute(ctx context.Context) error {
}

toLoadDataPaths := it.req.GetDataPaths()
keys := make([]string, 0)
blobs := make([]*Blob, 0)
for _, path := range toLoadDataPaths {
keys = append(keys, getKeyByPathNaive(path))
blob, err := getBlobByPath(path)
keys := make([]string, len(toLoadDataPaths))
blobs := make([]*Blob, len(toLoadDataPaths))

loadKey := func(idx int) error {
keys[idx] = getKeyByPathNaive(toLoadDataPaths[idx])
blob, err := getBlobByPath(toLoadDataPaths[idx])
if err != nil {
return err
}
blobs = append(blobs, blob)

blobs[idx] = blob

return nil
}
err = funcutil.ProcessFuncParallel(len(toLoadDataPaths), runtime.NumCPU(), loadKey, "loadKey")
if err != nil {
return err
}

storageBlobs := getStorageBlobs(blobs)

@@ -295,15 +304,25 @@ func (it *IndexBuildTask) Execute(ctx context.Context) error {
return it.kv.Save(path, string(value))
}

it.savePaths = make([]string, 0)
for _, blob := range serializedIndexBlobs {
it.savePaths = make([]string, len(serializedIndexBlobs))
saveIndexFile := func(idx int) error {
blob := serializedIndexBlobs[idx]
key, value := blob.Key, blob.Value

savePath := getSavePathByKey(key)

err := saveBlob(savePath, value)
if err != nil {
return err
}
it.savePaths = append(it.savePaths, savePath)

it.savePaths[idx] = savePath

return nil
}
err = funcutil.ProcessFuncParallel(len(serializedIndexBlobs), runtime.NumCPU(), saveIndexFile, "saveIndexFile")
if err != nil {
return err
}
}
// err = it.index.Delete()

@@ -965,8 +965,8 @@ func (node *ProxyNode) Insert(ctx context.Context, request *milvuspb.InsertReque
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName),
zap.Any("row data", "too many and too big, ignored"),
zap.Any("hash keys", "too many, ignored"))
zap.Any("len(RowData)", len(it.RowData)),
zap.Any("len(RowIDs)", len(it.RowIDs)))
defer func() {
log.Debug("Insert Done",
zap.Error(err),

@@ -976,8 +976,8 @@ func (node *ProxyNode) Insert(ctx context.Context, request *milvuspb.InsertReque
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName),
zap.Any("row data", "too many and too big, ignored"),
zap.Any("hash keys", "too many, ignored"))
zap.Any("len(RowData)", len(it.RowData)),
zap.Any("len(RowIDs)", len(it.RowIDs)))
}()

err = it.WaitToFinish()

@@ -1027,7 +1027,7 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("dsl", request.Dsl),
zap.Any("placeholder group", "too many and too big, ignored"))
zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)))
defer func() {
log.Debug("Search Done",
zap.Error(err),

@@ -1038,7 +1038,7 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("dsl", request.Dsl),
zap.Any("placeholder group", "too many and too big, ignored"))
zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)))
}()

err = qt.WaitToFinish()

@@ -8,7 +8,9 @@ import (
"regexp"
"runtime"
"strconv"
"sync"
"time"

"github.com/zilliztech/milvus-distributed/internal/util/funcutil"

"go.uber.org/zap"

@@ -617,83 +619,10 @@ func (st *SearchTask) Execute(ctx context.Context) error {
return err
}

func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
hits := make([][]*milvuspb.Hits, 0)
for _, partialSearchResult := range searchResults {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
continue
}

partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))

var err error
for i := range partialSearchResult.Hits {
j := i

func(idx int) {
partialHit := &milvuspb.Hits{}
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
}
partialHits[idx] = partialHit
}(j)
}

if err != nil {
return nil, err
}

hits = append(hits, partialHits)
}

return hits, nil
}

// TODO: add benchmark to compare with serial implementation
func decodeSearchResultsParallel(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
hits := make([][]*milvuspb.Hits, 0)
// necessary to parallel this?
for _, partialSearchResult := range searchResults {
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
continue
}
func decodeSearchResultsParallel(searchResults []*internalpb.SearchResults, maxParallel int) ([][]*milvuspb.Hits, error) {
log.Debug("decodeSearchResultsParallel", zap.Any("NumOfGoRoutines", maxParallel))

// necessary to check nq (len(partialSearchResult.Hits) here)?
partialHits := make([]*milvuspb.Hits, len(partialSearchResult.Hits))

var wg sync.WaitGroup
var err error
for i := range partialSearchResult.Hits {
j := i
wg.Add(1)

go func(idx int) {
defer wg.Done()
partialHit := &milvuspb.Hits{}
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
}
partialHits[idx] = partialHit
}(j)
}

if err != nil {
return nil, err
}

wg.Wait()

hits = append(hits, partialHits)
}

return hits, nil
}

// TODO: add benchmark to compare with serial implementation
func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
log.Debug("ProcessSearchResultParallel", zap.Any("runtime.NumCPU", runtime.NumCPU()))
hits := make([][]*milvuspb.Hits, 0)
// necessary to parallel this?
for _, partialSearchResult := range searchResults {

@@ -702,279 +631,154 @@ func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults)
}

nq := len(partialSearchResult.Hits)
maxParallel := runtime.NumCPU()
nqPerBatch := (nq + maxParallel - 1) / maxParallel
partialHits := make([]*milvuspb.Hits, nq)

var wg sync.WaitGroup
var err error
for i := 0; i < nq; i = i + nqPerBatch {
j := i
wg.Add(1)
f := func(idx int) error {
partialHit := &milvuspb.Hits{}

go func(begin int) {
defer wg.Done()
end := getMin(nq, begin+nqPerBatch)
for idx := begin; idx < end; idx++ {
partialHit := &milvuspb.Hits{}
err = proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
log.Debug("proxynode", zap.Any("unmarshal search result error", err))
}
partialHits[idx] = partialHit
}
}(j)
err := proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
if err != nil {
return err
}

partialHits[idx] = partialHit

return nil
}

err := funcutil.ProcessFuncParallel(nq, maxParallel, f, "decodePartialSearchResult")

if err != nil {
return nil, err
}

wg.Wait()

hits = append(hits, partialHits)
}

return hits, nil
}

func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
return decodeSearchResultsParallel(searchResults, 1)
}

// TODO: add benchmark to compare with serial implementation
func decodeSearchResultsParallelByNq(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
if len(searchResults) <= 0 {
return nil, errors.New("no need to decode empty search results")
}
nq := len(searchResults[0].Hits)
return decodeSearchResultsParallel(searchResults, nq)
}

// TODO: add benchmark to compare with serial implementation
func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
return decodeSearchResultsParallel(searchResults, runtime.NumCPU())
}

func decodeSearchResults(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
return decodeSearchResultsParallel(searchResults)
t := time.Now()
defer func() {
log.Debug("decodeSearchResults", zap.Any("time cost", time.Since(t)))
}()
return decodeSearchResultsParallelByCPU(searchResults)
}

func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) *milvuspb.SearchResults {
log.Debug("reduceSearchResultsParallel", zap.Any("NumOfGoRoutines", maxParallel))

ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}

const minFloat32 = -1 * float32(math.MaxFloat32)

f := func(idx int) error {
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}

for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check if distance is valid, `invalid` here means very very big,
// in this process, distance here is the smallest, so the rest of distance are all invalid
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}

if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}

reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
return err
}

ret.Hits[idx] = reducedHitsBs

return nil
}

err := funcutil.ProcessFuncParallel(nq, maxParallel, f, "reduceSearchResults")
if err != nil {
return nil
}

return ret
}

func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}

const minFloat32 = -1 * float32(math.MaxFloat32)

for i := 0; i < nq; i++ {
j := i
func(idx int) {
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}

for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check if distance is valid, `invalid` here means very very big,
// in this process, distance here is the smallest, so the rest of distance are all invalid
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}

if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}

reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
}

ret.Hits[idx] = reducedHitsBs
}(j)
}

return ret
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType, 1)
}

// TODO: add benchmark to compare with simple serial implementation
func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}

const minFloat32 = -1 * float32(math.MaxFloat32)

var wg sync.WaitGroup
for i := 0; i < nq; i++ {
j := i
wg.Add(1)
go func(idx int) {
defer wg.Done()

locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}

for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check if distance is valid, `invalid` here means very very big,
// in this process, distance here is the smallest, so the rest of distance are all invalid
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}

if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}

reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
}

ret.Hits[idx] = reducedHitsBs
}(j)

}

wg.Wait()

return ret
// TODO: add benchmark to compare with serial implementation
func reduceSearchResultsParallelByNq(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType, nq)
}

// TODO: add benchmark to compare with simple serial implementation
// TODO: add benchmark to compare with serial implementation
func reduceSearchResultsParallelByCPU(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
maxParallel := runtime.NumCPU()
nqPerBatch := (nq + maxParallel - 1) / maxParallel

ret := &milvuspb.SearchResults{
Status: &commonpb.Status{
ErrorCode: 0,
},
Hits: make([][]byte, nq),
}

const minFloat32 = -1 * float32(math.MaxFloat32)

var wg sync.WaitGroup
for begin := 0; begin < nq; begin = begin + nqPerBatch {
j := begin

wg.Add(1)
go func(begin int) {
defer wg.Done()

end := getMin(nq, begin+nqPerBatch)

for idx := begin; idx < end; idx++ {
locs := make([]int, availableQueryNodeNum)
reducedHits := &milvuspb.Hits{
IDs: make([]int64, 0),
RowData: make([][]byte, 0),
Scores: make([]float32, 0),
}

for j := 0; j < topk; j++ {
valid := false
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
if loc >= len(hits[q][idx].IDs) {
continue
}
distance := hits[q][idx].Scores[loc]
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
choice = q
maxDistance = distance
valid = true
}
}
if !valid {
break
}
choiceOffset := locs[choice]
// check if distance is valid, `invalid` here means very very big,
// in this process, distance here is the smallest, so the rest of distance are all invalid
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
}
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
locs[choice]++
}

if metricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}

reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Debug("proxynode", zap.String("error", "marshal error"))
}

ret.Hits[idx] = reducedHitsBs
}
}(j)

}

wg.Wait()

return ret
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType, runtime.NumCPU())
}

func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType)
t := time.Now()
defer func() {
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
}()
return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType)
}

func printSearchResult(partialSearchResult *internalpb.SearchResults) {

@@ -990,6 +794,10 @@ func printSearchResult(partialSearchResult *internalpb.SearchResults) {
}

func (st *SearchTask) PostExecute(ctx context.Context) error {
t0 := time.Now()
defer func() {
log.Debug("WaitAndPostExecute", zap.Any("time cost", time.Since(t0)))
}()
for {
select {
case <-st.TraceCtx().Done():

@@ -0,0 +1,91 @@
package funcutil

import (
"reflect"
"runtime"
"time"

"github.com/zilliztech/milvus-distributed/internal/log"
"go.uber.org/zap"
)

func GetFunctionName(i interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
}

// Reference: https://stackoverflow.com/questions/40809504/idiomatic-goroutine-termination-and-error-handling
func ProcessFuncParallel(total, maxParallel int, f func(idx int) error, fname string) error {
if maxParallel <= 0 {
maxParallel = 1
}

t := time.Now()
defer func() {
log.Debug(fname, zap.Any("time cost", time.Since(t)))
}()

nPerBatch := (total + maxParallel - 1) / maxParallel
log.Debug(fname, zap.Any("total", total))
log.Debug(fname, zap.Any("nPerBatch", nPerBatch))

quit := make(chan bool)
errc := make(chan error)
done := make(chan error)
getMin := func(a, b int) int {
if a < b {
return a
}
return b
}
routineNum := 0
for begin := 0; begin < total; begin = begin + nPerBatch {
j := begin

go func(begin int) {
err := error(nil)

end := getMin(total, begin+nPerBatch)
for idx := begin; idx < end; idx++ {
err = f(idx)
if err != nil {
log.Debug(fname, zap.Error(err), zap.Any("idx", idx))
break
}
}

ch := done // send to done channel
if err != nil {
ch = errc // send to error channel
}

select {
case ch <- err:
return
case <-quit:
return
}
}(j)

routineNum++
}

log.Debug(fname, zap.Any("NumOfGoRoutines", routineNum))

if routineNum <= 0 {
return nil
}

count := 0
for {
select {
case err := <-errc:
close(quit)
return err
case <-done:
count++
if count == routineNum {
return nil
}
}
}
}
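
The helper above batches the index range, fans the batches out to goroutines, and lets the first error close the quit channel so the remaining workers can bail out. The commit passes literal labels such as "loadKey" as fname, but a hypothetical pairing with the GetFunctionName helper added in the same file could derive that label by reflection (a sketch, not part of this commit):

    loadKey := func(idx int) error { return nil }  // placeholder work function
    name := funcutil.GetFunctionName(loadKey)      // reflection-derived label for the debug logs
    err := funcutil.ProcessFuncParallel(total, runtime.NumCPU(), loadKey, name)
    _ = err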

@@ -0,0 +1,86 @@
package funcutil

import (
"errors"
"math/rand"
"runtime"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
)

func TestProcessFuncParallel(t *testing.T) {
total := 64
s := make([]int, total)

expectedS := make([]int, total)
for i := range expectedS {
expectedS[i] = i
}

naiveF := func(idx int) error {
s[idx] = idx
return nil
}

var err error

err = ProcessFuncParallel(total, 1, naiveF, "naiveF") // serial
assert.Equal(t, err, nil, "process function serially must be right")
assert.Equal(t, s, expectedS, "process function serially must be right")

err = ProcessFuncParallel(total, total, naiveF, "naiveF") // Totally Parallel
assert.Equal(t, err, nil, "process function parallel must be right")
assert.Equal(t, s, expectedS, "process function parallel must be right")

err = ProcessFuncParallel(total, runtime.NumCPU(), naiveF, "naiveF") // Parallel by CPU
assert.Equal(t, err, nil, "process function parallel must be right")
assert.Equal(t, s, expectedS, "process function parallel must be right")

oddErrorF := func(idx int) error {
if idx%2 == 1 {
return errors.New("odd location: " + strconv.Itoa(idx))
}
return nil
}

err = ProcessFuncParallel(total, 1, oddErrorF, "oddErrorF") // serial
assert.NotEqual(t, err, nil, "process function serially must be right")

err = ProcessFuncParallel(total, total, oddErrorF, "oddErrorF") // Totally Parallel
assert.NotEqual(t, err, nil, "process function parallel must be right")

err = ProcessFuncParallel(total, runtime.NumCPU(), oddErrorF, "oddErrorF") // Parallel by CPU
assert.NotEqual(t, err, nil, "process function parallel must be right")

evenErrorF := func(idx int) error {
if idx%2 == 0 {
return errors.New("even location: " + strconv.Itoa(idx))
}
return nil
}

err = ProcessFuncParallel(total, 1, evenErrorF, "evenErrorF") // serial
assert.NotEqual(t, err, nil, "process function serially must be right")

err = ProcessFuncParallel(total, total, evenErrorF, "evenErrorF") // Totally Parallel
assert.NotEqual(t, err, nil, "process function parallel must be right")

err = ProcessFuncParallel(total, runtime.NumCPU(), evenErrorF, "evenErrorF") // Parallel by CPU
assert.NotEqual(t, err, nil, "process function parallel must be right")

// rand.Int() may be always a even number
randomErrorF := func(idx int) error {
if rand.Int()%2 == 0 {
return errors.New("random location: " + strconv.Itoa(idx))
}
return nil
}

ProcessFuncParallel(total, 1, randomErrorF, "randomErrorF") // serial

ProcessFuncParallel(total, total, randomErrorF, "randomErrorF") // Totally Parallel

ProcessFuncParallel(total, runtime.NumCPU(), randomErrorF, "randomErrorF") // Parallel by CPU
}