mirror of https://github.com/milvus-io/milvus.git
364 lines
10 KiB
Go
364 lines
10 KiB
Go
package executor
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
|
|
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
|
|
|
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
|
"github.com/milvus-io/milvus/pkg/common"
|
|
)
|
|
|
|
func prepareSearchParams(limitClause *planner.NodeLimitClause, annsClause *planner.NodeANNSClause) map[string]string {
|
|
searchParams := map[string]string{
|
|
common.AnnsFieldKey: annsClause.Column.Name,
|
|
common.TopKKey: fmt.Sprintf("%d", limitClause.Limit),
|
|
common.OffsetKey: fmt.Sprintf("%d", limitClause.Offset),
|
|
}
|
|
|
|
if annsClause.Params.IsSome() {
|
|
for k, v := range annsClause.Params.Unwrap().KVs {
|
|
searchParams[k] = v
|
|
}
|
|
params, _ := json.Marshal(annsClause.Params.Unwrap().KVs)
|
|
searchParams[common.SearchParamsKey] = string(params)
|
|
}
|
|
|
|
return searchParams
|
|
}
|
|
|
|
func generateOutputsIndex(outputs []string) (index map[string]int, users []string) {
|
|
index = make(map[string]int, len(outputs))
|
|
users = make([]string, 0, len(outputs))
|
|
|
|
fix := func(s string) string {
|
|
return strings.ToLower(strings.TrimSpace(s))
|
|
}
|
|
|
|
for i, f := range outputs {
|
|
fixed := fix(f)
|
|
if fixed != common.QueryNumberKey && fixed != common.DistanceKey {
|
|
users = append(users, f)
|
|
index[f] = i
|
|
} else {
|
|
index[fixed] = i
|
|
}
|
|
}
|
|
|
|
return index, users
|
|
}
|
|
|
|
func restoreExpr(from *planner.NodeFromClause) string {
|
|
filter := ""
|
|
if from.Where.IsSome() {
|
|
filter = planner.NewExprTextRestorer().RestoreExprText(from.Where.Unwrap())
|
|
}
|
|
return filter
|
|
}
|
|
|
|
func vector2PlaceholderGroupBytes(vectors []*planner.NodeVector) []byte {
|
|
phg := &commonpb.PlaceholderGroup{
|
|
Placeholders: []*commonpb.PlaceholderValue{
|
|
vector2Placeholder(vectors),
|
|
},
|
|
}
|
|
|
|
bs, _ := proto.Marshal(phg)
|
|
return bs
|
|
}
|
|
|
|
func vector2Placeholder(vectors []*planner.NodeVector) *commonpb.PlaceholderValue {
|
|
var placeHolderType commonpb.PlaceholderType
|
|
|
|
ph := &commonpb.PlaceholderValue{
|
|
Tag: "$0",
|
|
Values: make([][]byte, 0, len(vectors)),
|
|
}
|
|
|
|
if len(vectors) == 0 {
|
|
return ph
|
|
}
|
|
|
|
if vectors[0].FloatVector.IsSome() {
|
|
placeHolderType = commonpb.PlaceholderType_FloatVector
|
|
} else {
|
|
placeHolderType = commonpb.PlaceholderType_BinaryVector
|
|
|
|
}
|
|
|
|
ph.Type = placeHolderType
|
|
for _, vector := range vectors {
|
|
ph.Values = append(ph.Values, vector.Serialize())
|
|
}
|
|
|
|
return ph
|
|
}
|
|
|
|
func prepareSearchReq(tableName string, filter string, vectors []*planner.NodeVector, userOutputs []string, searchParams map[string]string) *milvuspb.SearchRequest {
|
|
phg := vector2PlaceholderGroupBytes(vectors)
|
|
req := &milvuspb.SearchRequest{
|
|
CollectionName: tableName,
|
|
PartitionNames: nil,
|
|
Dsl: filter,
|
|
PlaceholderGroup: phg,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
OutputFields: userOutputs,
|
|
SearchParams: funcutil.Map2KeyValuePair(searchParams),
|
|
TravelTimestamp: 0,
|
|
GuaranteeTimestamp: 2, // default bounded consistency level.
|
|
Nq: int64(len(vectors)),
|
|
}
|
|
return req
|
|
}
|
|
|
|
func wrapFieldsData(collectionName string, fieldsData []*schemapb.FieldData) *sqltypes.Result {
|
|
nColumn := len(fieldsData)
|
|
fields := make([]*querypb.Field, 0, nColumn)
|
|
|
|
if nColumn <= 0 {
|
|
return &sqltypes.Result{}
|
|
}
|
|
|
|
for i := 0; i < nColumn; i++ {
|
|
fields = append(fields, getSQLField(collectionName, fieldsData[i]))
|
|
}
|
|
|
|
nRow := typeutil.GetRowCount(fieldsData[0])
|
|
rows := make([][]sqltypes.Value, 0, nRow)
|
|
for i := 0; i < nRow; i++ {
|
|
row := make([]sqltypes.Value, 0, nColumn)
|
|
for j := 0; j < nColumn; j++ {
|
|
row = append(row, getDataSingle(fieldsData[j], i))
|
|
}
|
|
rows = append(rows, row)
|
|
}
|
|
|
|
return &sqltypes.Result{
|
|
Fields: fields,
|
|
Rows: rows,
|
|
}
|
|
}
|
|
|
|
func wrapQueryResults(res *milvuspb.QueryResults) *sqltypes.Result {
|
|
return wrapFieldsData(res.GetCollectionName(), res.GetFieldsData())
|
|
}
|
|
|
|
func wrapCountResult(rowCnt int, column string) *sqltypes.Result {
|
|
result1 := &sqltypes.Result{
|
|
Fields: []*querypb.Field{
|
|
{
|
|
Name: column,
|
|
Type: querypb.Type_INT64,
|
|
},
|
|
},
|
|
Rows: [][]sqltypes.Value{
|
|
{
|
|
sqltypes.NewInt64(int64(rowCnt)),
|
|
},
|
|
},
|
|
}
|
|
return result1
|
|
}
|
|
|
|
func getSQLField(tableName string, fieldData *schemapb.FieldData) *querypb.Field {
|
|
return &querypb.Field{
|
|
Name: fieldData.GetFieldName(),
|
|
Type: toSQLType(fieldData.GetType()),
|
|
Table: tableName,
|
|
OrgTable: "",
|
|
Database: "",
|
|
OrgName: "",
|
|
ColumnLength: 0,
|
|
Charset: 0,
|
|
Decimals: 0,
|
|
Flags: 0,
|
|
}
|
|
}
|
|
|
|
func toSQLType(t schemapb.DataType) querypb.Type {
|
|
switch t {
|
|
case schemapb.DataType_Bool:
|
|
// TODO: tinyint
|
|
return querypb.Type_UINT8
|
|
case schemapb.DataType_Int8:
|
|
return querypb.Type_INT8
|
|
case schemapb.DataType_Int16:
|
|
return querypb.Type_INT16
|
|
case schemapb.DataType_Int32:
|
|
return querypb.Type_INT32
|
|
case schemapb.DataType_Int64:
|
|
return querypb.Type_INT64
|
|
case schemapb.DataType_Float:
|
|
return querypb.Type_FLOAT32
|
|
case schemapb.DataType_Double:
|
|
return querypb.Type_FLOAT64
|
|
case schemapb.DataType_VarChar:
|
|
return querypb.Type_VARCHAR
|
|
// TODO: vector.
|
|
default:
|
|
return querypb.Type_NULL_TYPE
|
|
}
|
|
}
|
|
|
|
func getDataSingle(fieldData *schemapb.FieldData, idx int) sqltypes.Value {
|
|
switch fieldData.GetType() {
|
|
case schemapb.DataType_Bool:
|
|
// TODO: tinyint
|
|
return sqltypes.NewInt32(1)
|
|
case schemapb.DataType_Int8:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
|
return sqltypes.MakeTrusted(sqltypes.Int8, strconv.AppendInt(nil, int64(v), 10))
|
|
case schemapb.DataType_Int16:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
|
return sqltypes.MakeTrusted(sqltypes.Int16, strconv.AppendInt(nil, int64(v), 10))
|
|
case schemapb.DataType_Int32:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
|
return sqltypes.MakeTrusted(sqltypes.Int32, strconv.AppendInt(nil, int64(v), 10))
|
|
case schemapb.DataType_Int64:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_LongData).LongData.GetData()[idx]
|
|
return sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, v, 10))
|
|
case schemapb.DataType_Float:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_FloatData).FloatData.GetData()[idx]
|
|
return sqltypes.MakeTrusted(sqltypes.Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64))
|
|
case schemapb.DataType_Double:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_DoubleData).DoubleData.GetData()[idx]
|
|
return sqltypes.MakeTrusted(sqltypes.Float64, strconv.AppendFloat(nil, v, 'g', -1, 64))
|
|
case schemapb.DataType_VarChar:
|
|
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_StringData).StringData.GetData()[idx]
|
|
return sqltypes.NewVarChar(v)
|
|
|
|
// TODO: vector.
|
|
default:
|
|
// TODO: should raise error here.
|
|
return sqltypes.NewInt32(1)
|
|
}
|
|
}
|
|
|
|
func wrapSearchResult(res *milvuspb.SearchResults, outputsIndex map[string]int, userOutputs []string) *sqltypes.Result {
|
|
qnIndex, qnOk := outputsIndex[common.QueryNumberKey]
|
|
disIndex, disOk := outputsIndex[common.DistanceKey]
|
|
|
|
fieldsData := make([]*schemapb.FieldData, len(res.GetResults().GetFieldsData()))
|
|
copy(fieldsData, res.GetResults().GetFieldsData())
|
|
|
|
insert := func(fields []*schemapb.FieldData, field *schemapb.FieldData, index int) []*schemapb.FieldData {
|
|
if index >= len(fields) {
|
|
return append(fields, field)
|
|
}
|
|
return append(fieldsData[:index], append([]*schemapb.FieldData{field}, fieldsData[index:]...)...)
|
|
}
|
|
|
|
if qnOk && disOk {
|
|
qnField := generateQueryNumberFieldData(res)
|
|
disField := generateDistanceFieldsData(res)
|
|
|
|
index1, index2 := qnIndex, disIndex
|
|
f1, f2 := qnField, disField
|
|
if qnIndex > disIndex {
|
|
index1, index2 = disIndex, qnIndex
|
|
f1, f2 = disField, qnField
|
|
}
|
|
|
|
fieldsData = insert(fieldsData, f1, index1)
|
|
fieldsData = insert(fieldsData, f2, index2)
|
|
} else if qnOk {
|
|
qnField := generateQueryNumberFieldData(res)
|
|
fieldsData = insert(fieldsData, qnField, qnIndex)
|
|
} else if disOk {
|
|
disField := generateDistanceFieldsData(res)
|
|
fieldsData = insert(fieldsData, disField, disIndex)
|
|
}
|
|
|
|
return wrapFieldsData(res.GetCollectionName(), fieldsData)
|
|
}
|
|
|
|
func generateQueryNumberFieldData(res *milvuspb.SearchResults) *schemapb.FieldData {
|
|
arr := make([]int64, 0, res.GetResults().GetNumQueries()*res.GetResults().GetTopK())
|
|
|
|
for i, topk := range res.GetResults().GetTopks() {
|
|
for j := 0; int64(j) < topk; j++ {
|
|
arr = append(arr, int64(i))
|
|
}
|
|
}
|
|
|
|
return &schemapb.FieldData{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldName: common.QueryNumberKey,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: arr,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func generateDistanceFieldsData(res *milvuspb.SearchResults) *schemapb.FieldData {
|
|
return &schemapb.FieldData{
|
|
Type: schemapb.DataType_Float,
|
|
FieldName: common.DistanceKey,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_FloatData{
|
|
FloatData: &schemapb.FloatArray{
|
|
Data: res.GetResults().GetScores(),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func getOutputFieldsOrMatchCountRule(fields []*planner.NodeSelectElement) (outputFields []string, match bool, err error) {
|
|
match = false
|
|
l := len(fields)
|
|
|
|
if l == 1 {
|
|
entry := fields[0]
|
|
match = entry.FunctionCall.IsSome() &&
|
|
entry.FunctionCall.Unwrap().Agg.IsSome() &&
|
|
entry.FunctionCall.Unwrap().Agg.Unwrap().AggCount.IsSome()
|
|
}
|
|
|
|
if match {
|
|
return nil, match, nil
|
|
}
|
|
|
|
outputFields = make([]string, 0, l)
|
|
for _, entry := range fields {
|
|
if entry.Star.IsSome() {
|
|
// TODO: support `select *`.
|
|
return nil, match, fmt.Errorf("* is not supported")
|
|
}
|
|
|
|
if entry.FunctionCall.IsSome() {
|
|
return nil, match, fmt.Errorf("combined select elements is not supported")
|
|
}
|
|
|
|
if entry.FullColumnName.IsSome() {
|
|
c := entry.FullColumnName.Unwrap()
|
|
if c.Alias.IsSome() {
|
|
return nil, match, fmt.Errorf("alias for select elements is not supported")
|
|
}
|
|
outputFields = append(outputFields, c.Name)
|
|
}
|
|
}
|
|
|
|
return outputFields, false, nil
|
|
}
|