milvus/internal/mysqld/executor/utils.go

364 lines
10 KiB
Go

package executor
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/util/typeutil"
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/milvus-io/milvus/pkg/common"
)
func prepareSearchParams(limitClause *planner.NodeLimitClause, annsClause *planner.NodeANNSClause) map[string]string {
searchParams := map[string]string{
common.AnnsFieldKey: annsClause.Column.Name,
common.TopKKey: fmt.Sprintf("%d", limitClause.Limit),
common.OffsetKey: fmt.Sprintf("%d", limitClause.Offset),
}
if annsClause.Params.IsSome() {
for k, v := range annsClause.Params.Unwrap().KVs {
searchParams[k] = v
}
params, _ := json.Marshal(annsClause.Params.Unwrap().KVs)
searchParams[common.SearchParamsKey] = string(params)
}
return searchParams
}
func generateOutputsIndex(outputs []string) (index map[string]int, users []string) {
index = make(map[string]int, len(outputs))
users = make([]string, 0, len(outputs))
fix := func(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
for i, f := range outputs {
fixed := fix(f)
if fixed != common.QueryNumberKey && fixed != common.DistanceKey {
users = append(users, f)
index[f] = i
} else {
index[fixed] = i
}
}
return index, users
}
func restoreExpr(from *planner.NodeFromClause) string {
filter := ""
if from.Where.IsSome() {
filter = planner.NewExprTextRestorer().RestoreExprText(from.Where.Unwrap())
}
return filter
}
func vector2PlaceholderGroupBytes(vectors []*planner.NodeVector) []byte {
phg := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
vector2Placeholder(vectors),
},
}
bs, _ := proto.Marshal(phg)
return bs
}
func vector2Placeholder(vectors []*planner.NodeVector) *commonpb.PlaceholderValue {
var placeHolderType commonpb.PlaceholderType
ph := &commonpb.PlaceholderValue{
Tag: "$0",
Values: make([][]byte, 0, len(vectors)),
}
if len(vectors) == 0 {
return ph
}
if vectors[0].FloatVector.IsSome() {
placeHolderType = commonpb.PlaceholderType_FloatVector
} else {
placeHolderType = commonpb.PlaceholderType_BinaryVector
}
ph.Type = placeHolderType
for _, vector := range vectors {
ph.Values = append(ph.Values, vector.Serialize())
}
return ph
}
func prepareSearchReq(tableName string, filter string, vectors []*planner.NodeVector, userOutputs []string, searchParams map[string]string) *milvuspb.SearchRequest {
phg := vector2PlaceholderGroupBytes(vectors)
req := &milvuspb.SearchRequest{
CollectionName: tableName,
PartitionNames: nil,
Dsl: filter,
PlaceholderGroup: phg,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: userOutputs,
SearchParams: funcutil.Map2KeyValuePair(searchParams),
TravelTimestamp: 0,
GuaranteeTimestamp: 2, // default bounded consistency level.
Nq: int64(len(vectors)),
}
return req
}
func wrapFieldsData(collectionName string, fieldsData []*schemapb.FieldData) *sqltypes.Result {
nColumn := len(fieldsData)
fields := make([]*querypb.Field, 0, nColumn)
if nColumn <= 0 {
return &sqltypes.Result{}
}
for i := 0; i < nColumn; i++ {
fields = append(fields, getSQLField(collectionName, fieldsData[i]))
}
nRow := typeutil.GetRowCount(fieldsData[0])
rows := make([][]sqltypes.Value, 0, nRow)
for i := 0; i < nRow; i++ {
row := make([]sqltypes.Value, 0, nColumn)
for j := 0; j < nColumn; j++ {
row = append(row, getDataSingle(fieldsData[j], i))
}
rows = append(rows, row)
}
return &sqltypes.Result{
Fields: fields,
Rows: rows,
}
}
func wrapQueryResults(res *milvuspb.QueryResults) *sqltypes.Result {
return wrapFieldsData(res.GetCollectionName(), res.GetFieldsData())
}
func wrapCountResult(rowCnt int, column string) *sqltypes.Result {
result1 := &sqltypes.Result{
Fields: []*querypb.Field{
{
Name: column,
Type: querypb.Type_INT64,
},
},
Rows: [][]sqltypes.Value{
{
sqltypes.NewInt64(int64(rowCnt)),
},
},
}
return result1
}
func getSQLField(tableName string, fieldData *schemapb.FieldData) *querypb.Field {
return &querypb.Field{
Name: fieldData.GetFieldName(),
Type: toSQLType(fieldData.GetType()),
Table: tableName,
OrgTable: "",
Database: "",
OrgName: "",
ColumnLength: 0,
Charset: 0,
Decimals: 0,
Flags: 0,
}
}
func toSQLType(t schemapb.DataType) querypb.Type {
switch t {
case schemapb.DataType_Bool:
// TODO: tinyint
return querypb.Type_UINT8
case schemapb.DataType_Int8:
return querypb.Type_INT8
case schemapb.DataType_Int16:
return querypb.Type_INT16
case schemapb.DataType_Int32:
return querypb.Type_INT32
case schemapb.DataType_Int64:
return querypb.Type_INT64
case schemapb.DataType_Float:
return querypb.Type_FLOAT32
case schemapb.DataType_Double:
return querypb.Type_FLOAT64
case schemapb.DataType_VarChar:
return querypb.Type_VARCHAR
// TODO: vector.
default:
return querypb.Type_NULL_TYPE
}
}
func getDataSingle(fieldData *schemapb.FieldData, idx int) sqltypes.Value {
switch fieldData.GetType() {
case schemapb.DataType_Bool:
// TODO: tinyint
return sqltypes.NewInt32(1)
case schemapb.DataType_Int8:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
return sqltypes.MakeTrusted(sqltypes.Int8, strconv.AppendInt(nil, int64(v), 10))
case schemapb.DataType_Int16:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
return sqltypes.MakeTrusted(sqltypes.Int16, strconv.AppendInt(nil, int64(v), 10))
case schemapb.DataType_Int32:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
return sqltypes.MakeTrusted(sqltypes.Int32, strconv.AppendInt(nil, int64(v), 10))
case schemapb.DataType_Int64:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_LongData).LongData.GetData()[idx]
return sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, v, 10))
case schemapb.DataType_Float:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_FloatData).FloatData.GetData()[idx]
return sqltypes.MakeTrusted(sqltypes.Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64))
case schemapb.DataType_Double:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_DoubleData).DoubleData.GetData()[idx]
return sqltypes.MakeTrusted(sqltypes.Float64, strconv.AppendFloat(nil, v, 'g', -1, 64))
case schemapb.DataType_VarChar:
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_StringData).StringData.GetData()[idx]
return sqltypes.NewVarChar(v)
// TODO: vector.
default:
// TODO: should raise error here.
return sqltypes.NewInt32(1)
}
}
func wrapSearchResult(res *milvuspb.SearchResults, outputsIndex map[string]int, userOutputs []string) *sqltypes.Result {
qnIndex, qnOk := outputsIndex[common.QueryNumberKey]
disIndex, disOk := outputsIndex[common.DistanceKey]
fieldsData := make([]*schemapb.FieldData, len(res.GetResults().GetFieldsData()))
copy(fieldsData, res.GetResults().GetFieldsData())
insert := func(fields []*schemapb.FieldData, field *schemapb.FieldData, index int) []*schemapb.FieldData {
if index >= len(fields) {
return append(fields, field)
}
return append(fieldsData[:index], append([]*schemapb.FieldData{field}, fieldsData[index:]...)...)
}
if qnOk && disOk {
qnField := generateQueryNumberFieldData(res)
disField := generateDistanceFieldsData(res)
index1, index2 := qnIndex, disIndex
f1, f2 := qnField, disField
if qnIndex > disIndex {
index1, index2 = disIndex, qnIndex
f1, f2 = disField, qnField
}
fieldsData = insert(fieldsData, f1, index1)
fieldsData = insert(fieldsData, f2, index2)
} else if qnOk {
qnField := generateQueryNumberFieldData(res)
fieldsData = insert(fieldsData, qnField, qnIndex)
} else if disOk {
disField := generateDistanceFieldsData(res)
fieldsData = insert(fieldsData, disField, disIndex)
}
return wrapFieldsData(res.GetCollectionName(), fieldsData)
}
func generateQueryNumberFieldData(res *milvuspb.SearchResults) *schemapb.FieldData {
arr := make([]int64, 0, res.GetResults().GetNumQueries()*res.GetResults().GetTopK())
for i, topk := range res.GetResults().GetTopks() {
for j := 0; int64(j) < topk; j++ {
arr = append(arr, int64(i))
}
}
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: common.QueryNumberKey,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: arr,
},
},
},
},
}
}
func generateDistanceFieldsData(res *milvuspb.SearchResults) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Float,
FieldName: common.DistanceKey,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: res.GetResults().GetScores(),
},
},
},
},
}
}
func getOutputFieldsOrMatchCountRule(fields []*planner.NodeSelectElement) (outputFields []string, match bool, err error) {
match = false
l := len(fields)
if l == 1 {
entry := fields[0]
match = entry.FunctionCall.IsSome() &&
entry.FunctionCall.Unwrap().Agg.IsSome() &&
entry.FunctionCall.Unwrap().Agg.Unwrap().AggCount.IsSome()
}
if match {
return nil, match, nil
}
outputFields = make([]string, 0, l)
for _, entry := range fields {
if entry.Star.IsSome() {
// TODO: support `select *`.
return nil, match, fmt.Errorf("* is not supported")
}
if entry.FunctionCall.IsSome() {
return nil, match, fmt.Errorf("combined select elements is not supported")
}
if entry.FullColumnName.IsSome() {
c := entry.FullColumnName.Unwrap()
if c.Alias.IsSome() {
return nil, match, fmt.Errorf("alias for select elements is not supported")
}
outputFields = append(outputFields, c.Name)
}
}
return outputFields, false, nil
}