milvus/internal/util/function/rerank/rerank_base.go

115 lines
3.3 KiB
Go

/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"fmt"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
const (
reranker string = "reranker"
)
// topk and group related parameters, reranker can choose to process or ignore
type searchParams struct {
limit int64
groupByFieldId int64
groupSize int64
strictGroupSize bool
}
type RerankBase struct {
rerankerName string
isSupportGroup bool
pkType schemapb.DataType
inputFieldIDs []int64
inputFieldNames []string
inputFieldTypes []schemapb.DataType
// TODO: The parameter is passed to the reranker, and the reranker decides whether to implement the parameter
searchParams *searchParams
}
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool) (*RerankBase, error) {
pkType, err := getPKType(coll)
if err != nil {
return nil, err
}
base := RerankBase{
inputFieldNames: funcSchema.InputFieldNames,
rerankerName: rerankerName,
isSupportGroup: isSupportGroup,
pkType: pkType,
}
nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) {
return field.GetName(), field
})
if len(funcSchema.GetOutputFieldNames()) != 0 {
return nil, fmt.Errorf("Rerank function output field names should be empty")
}
for _, name := range funcSchema.GetInputFieldNames() {
if name == "" {
return nil, fmt.Errorf("Rerank input field name cannot be empty string")
}
if lo.Count(funcSchema.GetInputFieldNames(), name) > 1 {
return nil, fmt.Errorf("Each function input field should be used exactly once in the same function, input field: %s", name)
}
inputField, ok := nameMap[name]
if !ok {
return nil, fmt.Errorf("Function input field not found: %s", name)
}
if inputField.GetNullable() {
return nil, fmt.Errorf("Function input field cannot be nullable: field %s", inputField.GetName())
}
base.inputFieldIDs = append(base.inputFieldIDs, inputField.FieldID)
base.inputFieldTypes = append(base.inputFieldTypes, inputField.DataType)
}
return &base, nil
}
func (base *RerankBase) GetInputFieldNames() []string {
return base.inputFieldNames
}
func (base *RerankBase) GetInputFieldTypes() []schemapb.DataType {
return base.inputFieldTypes
}
func (base *RerankBase) GetInputFieldIDs() []int64 {
return base.inputFieldIDs
}
func (base *RerankBase) IsSupportGroup() bool {
return base.isSupportGroup
}
func (base *RerankBase) GetRankName() string {
return base.rerankerName
}