milvus/internal/mysqld/executor/executor.go

230 lines
6.5 KiB
Go

package executor
import (
"context"
"fmt"
"strconv"
"github.com/milvus-io/milvus/pkg/common"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/mysqld/parser/antlrparser"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/milvus-io/milvus/internal/types"
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
)
// Executor executes a physical plan produced by the planner.
type Executor interface {
// Run executes plan and returns the outcome as a sqltypes.Result, or an
// error when the plan cannot be executed.
Run(ctx context.Context, plan *planner.PhysicalPlan) (*sqltypes.Result, error)
}
// defaultExecutor only translates sql to rpc. TODO: Better to use volcano model or batch model.
type defaultExecutor struct {
// s is the proxy component that the translated requests (Search, Query,
// GetCollectionStatistics) are sent to.
s types.ProxyComponent
}
// Run extracts the SQL statements from plan.Node and executes them.
//
// Exactly one statement must be present; an error is returned when the node
// does not resolve to statements or when their number is not one.
func (e *defaultExecutor) Run(ctx context.Context, plan *planner.PhysicalPlan) (*sqltypes.Result, error) {
	parsed := antlrparser.GetSqlStatements(plan.Node)
	switch {
	case parsed == nil:
		return nil, fmt.Errorf("invalid node, sql should be parsed to statements")
	case len(parsed.Statements) != 1:
		return nil, fmt.Errorf("only one statement is supported")
	}
	return e.dispatch(ctx, parsed.Statements[0])
}
// dispatch routes a SQL statement to its handler. Only DML statements are
// handled; anything else yields an error.
func (e *defaultExecutor) dispatch(ctx context.Context, n *planner.NodeSqlStatement) (*sqltypes.Result, error) {
	if !n.DmlStatement.IsSome() {
		return nil, fmt.Errorf("invalid sql statement, only dml statement is supported")
	}
	return e.dispatchDmlStatement(ctx, n.DmlStatement.Unwrap())
}
// dispatchDmlStatement routes a DML statement to its handler. Only SELECT
// statements are handled; anything else yields an error.
func (e *defaultExecutor) dispatchDmlStatement(ctx context.Context, n *planner.NodeDmlStatement) (*sqltypes.Result, error) {
	if !n.SelectStatement.IsSome() {
		return nil, fmt.Errorf("invalid dml statement, only select statement is supported")
	}
	return e.execSelect(ctx, n.SelectStatement.Unwrap())
}
// execSelect validates a SELECT statement and executes it in one of four ways:
//
//   - count rule matched, no WHERE: row count from collection statistics;
//   - count rule matched, with WHERE: row count of a filtered query;
//   - ANNS clause present: vector search via execANNS;
//   - otherwise: a filtered query (a WHERE clause is mandatory, LIMIT is rejected).
//
// Only a simple select without lock clause and without select specs, reading
// from exactly one table source, is accepted.
func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectStatement) (*sqltypes.Result, error) {
	if !n.SimpleSelect.IsSome() {
		return nil, fmt.Errorf("invalid select statement, only simple select is supported")
	}

	simple := n.SimpleSelect.Unwrap()
	switch {
	case simple.LockClause.IsSome():
		return nil, fmt.Errorf("invalid simple select statement, lock clause is not supported")
	case !simple.Query.IsSome():
		return nil, fmt.Errorf("invalid simple select statement, only query is supported")
	}

	spec := simple.Query.Unwrap()
	if len(spec.SelectSpecs) > 0 {
		return nil, fmt.Errorf("invalid query statement, select spec is not supported")
	}
	if !spec.From.IsSome() {
		return nil, fmt.Errorf("invalid query statement, table source is not specified")
	}

	fromClause := spec.From.Unwrap()
	if len(fromClause.TableSources) != 1 {
		return nil, fmt.Errorf("invalid query statement, only one table source is supported")
	}
	collection := fromClause.TableSources[0].TableName.Unwrap()

	fields, isCount, err := getOutputFieldsOrMatchCountRule(spec.SelectElements)
	if err != nil {
		return nil, err
	}

	if isCount {
		if !fromClause.Where.IsSome() {
			// Count without filter.
			total, err := e.execCountWithoutFilter(ctx, collection)
			if err != nil {
				return nil, err
			}
			return wrapCountResult(total, "count(*)"), nil
		}
		// Count with filter.
		expr := planner.NewExprTextRestorer().RestoreExprText(fromClause.Where.Unwrap())
		return e.execCountWithFilter(ctx, collection, expr)
	}

	// From here on the count rule did not match.
	switch {
	case spec.Anns.IsSome():
		// The already parsed output fields are reused for the search.
		return e.execANNS(ctx, spec, fields)
	case spec.Limit.IsSome():
		// TODO: use pagination.
		return nil, fmt.Errorf("invalid query statement, limit/offset is not supported")
	case !fromClause.Where.IsSome():
		// Query without filter.
		return nil, fmt.Errorf("query without filter is not supported")
	}

	expr := planner.NewExprTextRestorer().RestoreExprText(fromClause.Where.Unwrap())
	rows, err := e.execQuery(ctx, collection, expr, fields)
	if err != nil {
		return nil, err
	}
	return wrapQueryResults(rows), nil
}
// execANNS runs the ANNS clause of q as a Search request and wraps the response
// into a result set.
//
// A LIMIT clause is mandatory. outputs are the output fields parsed from the
// select elements. A non-Success response status is returned as a status error.
func (e *defaultExecutor) execANNS(ctx context.Context, q *planner.NodeQuerySpecification, outputs []string) (*sqltypes.Result, error) {
	if !q.Limit.IsSome() {
		return nil, fmt.Errorf("limit not specified in the ANNS statement")
	}

	anns := q.Anns.Unwrap()
	params := prepareSearchParams(q.Limit.Unwrap(), anns)
	index, requested := generateOutputsIndex(outputs)

	fromClause := q.From.Unwrap()
	expr := restoreExpr(fromClause)
	collection := fromClause.TableSources[0].TableName.Unwrap()

	resp, err := e.s.Search(ctx, prepareSearchReq(collection, expr, anns.Vectors, requested, params))
	if err != nil {
		return nil, err
	}
	if status := resp.GetStatus(); status.GetErrorCode() != commonpb.ErrorCode_Success {
		return nil, common.NewStatusError(status.GetErrorCode(), status.GetReason())
	}
	return wrapSearchResult(resp, index, requested), nil
}
// execCountWithFilter counts the rows of tableName that satisfy filter by
// querying all fields and taking the row count of the first returned field.
// A response without field data counts as zero rows.
func (e *defaultExecutor) execCountWithFilter(ctx context.Context, tableName string, filter string) (*sqltypes.Result, error) {
	// TODO: check if `*` match vector field.
	res, err := e.execQuery(ctx, tableName, filter, []string{"*"})
	if err != nil {
		return nil, err
	}

	count := 0
	if columns := res.GetFieldsData(); len(columns) > 0 {
		count = typeutil.GetRowCount(columns[0])
	}
	return wrapCountResult(count, "count(*)"), nil
}
// execCountWithoutFilter returns the row count of collection tableName as
// reported by GetCollectionStatistics.
//
// An error is returned when the RPC fails, when the response status is not
// Success, when the response carries no statistics entries, or when the value
// of the first entry is not a decimal integer.
func (e *defaultExecutor) execCountWithoutFilter(ctx context.Context, tableName string) (int, error) {
	req := &milvuspb.GetCollectionStatisticsRequest{
		Base:           commonpbutil.NewMsgBase(),
		CollectionName: tableName,
	}
	resp, err := e.s.GetCollectionStatistics(ctx, req)
	if err != nil {
		return 0, err
	}
	if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		return 0, errors.New(resp.GetStatus().GetReason())
	}

	// Guard the index below: a successful response with an empty stats list
	// would otherwise panic with an index-out-of-range error.
	stats := resp.GetStats()
	if len(stats) == 0 {
		return 0, fmt.Errorf("no statistics returned for collection %q", tableName)
	}

	// NOTE(review): the first entry is taken to be the row count, as before;
	// confirm against the proxy's GetCollectionStatistics contract.
	rowCnt, err := strconv.Atoi(stats[0].GetValue())
	if err != nil {
		return 0, fmt.Errorf("parsing row count of collection %q: %w", tableName, err)
	}
	return rowCnt, nil
}
// execQuery sends a Query request for tableName with the given filter
// expression and output fields, and returns the raw response.
//
// A non-Success response status is converted into an error carrying the
// status reason.
func (e *defaultExecutor) execQuery(ctx context.Context, tableName string, filter string, outputs []string) (*milvuspb.QueryResults, error) {
	// DbName, PartitionNames, TravelTimestamp, GuaranteeTimestamp and
	// QueryParams are intentionally left at their zero values.
	res, err := e.s.Query(ctx, &milvuspb.QueryRequest{
		Base:           commonpbutil.NewMsgBase(),
		CollectionName: tableName,
		Expr:           filter,
		OutputFields:   outputs,
	})
	if err != nil {
		return nil, err
	}
	if status := res.GetStatus(); status.GetErrorCode() != commonpb.ErrorCode_Success {
		return nil, errors.New(status.GetReason())
	}
	return res, nil
}
// NewDefaultExecutor returns an Executor that sends its requests to the proxy
// component s.
func NewDefaultExecutor(s types.ProxyComponent) Executor {
	exec := new(defaultExecutor)
	exec.s = s
	return exec
}