milvus/internal/querynodev2/delegator/idf_oracle.go

264 lines
6.6 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"fmt"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log"
)
type IDFOracle interface {
// Activate(segmentID int64, state commonpb.SegmentState) error
// Deactivate(segmentID int64, state commonpb.SegmentState) error
SyncDistribution(snapshot *snapshot)
UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
Remove(segmentID int64, state commonpb.SegmentState)
BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
}
type bm25Stats struct {
stats map[int64]*storage.BM25Stats
activate bool
targetVersion int64
}
func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
for fieldID, newstats := range stats {
if stats, ok := s.stats[fieldID]; ok {
stats.Merge(newstats)
} else {
s.stats[fieldID] = storage.NewBM25Stats()
s.stats[fieldID].Merge(newstats)
}
}
}
func (s *bm25Stats) Minus(stats map[int64]*storage.BM25Stats) {
for fieldID, newstats := range stats {
if stats, ok := s.stats[fieldID]; ok {
stats.Minus(newstats)
} else {
log.Panic("minus failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
}
}
}
func (s *bm25Stats) GetStats(fieldID int64) (*storage.BM25Stats, error) {
stats, ok := s.stats[fieldID]
if !ok {
return nil, fmt.Errorf("field not found in idf oracle BM25 stats")
}
return stats, nil
}
func (s *bm25Stats) NumRow() int64 {
for _, stats := range s.stats {
return stats.NumRow()
}
return 0
}
func newBm25Stats(functions []*schemapb.FunctionSchema) *bm25Stats {
stats := &bm25Stats{
stats: make(map[int64]*storage.BM25Stats),
}
for _, function := range functions {
if function.GetType() == schemapb.FunctionType_BM25 {
stats.stats[function.GetOutputFieldIds()[0]] = storage.NewBM25Stats()
}
}
return stats
}
type idfOracle struct {
sync.RWMutex
current *bm25Stats
growing map[int64]*bm25Stats
sealed map[int64]*bm25Stats
targetVersion int64
}
func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) {
o.Lock()
defer o.Unlock()
switch state {
case segments.SegmentTypeGrowing:
if _, ok := o.growing[segmentID]; ok {
return
}
o.growing[segmentID] = &bm25Stats{
stats: stats,
activate: true,
targetVersion: initialTargetVersion,
}
o.current.Merge(stats)
case segments.SegmentTypeSealed:
if _, ok := o.sealed[segmentID]; ok {
return
}
o.sealed[segmentID] = &bm25Stats{
stats: stats,
activate: false,
targetVersion: initialTargetVersion,
}
default:
log.Warn("register segment with unknown state", zap.String("stats", state.String()))
return
}
}
func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) {
if len(stats) == 0 {
return
}
o.Lock()
defer o.Unlock()
old, ok := o.growing[segmentID]
if !ok {
return
}
old.Merge(stats)
if old.activate {
o.current.Merge(stats)
}
}
func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) {
o.Lock()
defer o.Unlock()
switch state {
case segments.SegmentTypeGrowing:
if stats, ok := o.growing[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.growing, segmentID)
}
case segments.SegmentTypeSealed:
if stats, ok := o.sealed[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.sealed, segmentID)
}
default:
return
}
}
func (o *idfOracle) activate(stats *bm25Stats) {
stats.activate = true
o.current.Merge(stats.stats)
}
func (o *idfOracle) deactivate(stats *bm25Stats) {
stats.activate = false
o.current.Minus(stats.stats)
}
func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
o.Lock()
defer o.Unlock()
sealed, growing := snapshot.Peek()
for _, item := range sealed {
for _, segment := range item.Segments {
if stats, ok := o.sealed[segment.SegmentID]; ok {
stats.targetVersion = segment.TargetVersion
} else {
log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
}
}
}
for _, segment := range growing {
if stats, ok := o.growing[segment.SegmentID]; ok {
stats.targetVersion = segment.TargetVersion
} else {
log.Warn("idf oracle lack some growing segment", zap.Int64("segmentID", segment.SegmentID))
}
}
o.targetVersion = snapshot.targetVersion
for _, stats := range o.sealed {
if !stats.activate && stats.targetVersion == o.targetVersion {
o.activate(stats)
} else if stats.activate && stats.targetVersion != o.targetVersion {
o.deactivate(stats)
}
}
for _, stats := range o.growing {
if !stats.activate && (stats.targetVersion == o.targetVersion || stats.targetVersion == initialTargetVersion) {
o.activate(stats)
} else if stats.activate && (stats.targetVersion != o.targetVersion && stats.targetVersion != initialTargetVersion) {
o.deactivate(stats)
}
}
log.Debug("sync distribution finished", zap.Int64("version", o.targetVersion), zap.Int64("numrow", o.current.NumRow()))
}
func (o *idfOracle) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) {
o.RLock()
defer o.RUnlock()
stats, err := o.current.GetStats(fieldID)
if err != nil {
return nil, 0, err
}
idfBytes := make([][]byte, len(tfs.GetContents()))
for i, tf := range tfs.GetContents() {
idf := stats.BuildIDF(tf)
idfBytes[i] = idf
}
return idfBytes, stats.GetAvgdl(), nil
}
func NewIDFOracle(functions []*schemapb.FunctionSchema) IDFOracle {
return &idfOracle{
current: newBm25Stats(functions),
growing: make(map[int64]*bm25Stats),
sealed: make(map[int64]*bm25Stats),
}
}