mirror of https://github.com/milvus-io/milvus.git
Support to parse ANNS plan node (#23463)
Signed-off-by: longjiquan <jiquan.long@zilliz.com>pull/23506/head
parent
a455595c9b
commit
c013492762
|
@ -4,6 +4,8 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/sqlutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
||||
|
||||
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
|
||||
|
@ -172,6 +174,17 @@ func (v *AstBuilder) VisitQuerySpecification(ctx *parsergen.QuerySpecificationCo
|
|||
}
|
||||
}
|
||||
|
||||
annsCtx := ctx.AnnsClause()
|
||||
if annsCtx != nil {
|
||||
r := annsCtx.Accept(v)
|
||||
if err := GetError(r); err != nil {
|
||||
return err
|
||||
}
|
||||
if n := GetANNSClause(r); n != nil {
|
||||
opts = append(opts, planner.WithANNS(n))
|
||||
}
|
||||
}
|
||||
|
||||
limitCtx := ctx.LimitClause()
|
||||
if limitCtx != nil {
|
||||
r := limitCtx.Accept(v)
|
||||
|
@ -290,6 +303,167 @@ func (v *AstBuilder) VisitFromClause(ctx *parsergen.FromClauseContext) interface
|
|||
return planner.NewNodeFromClause(text, tableSources, opts...)
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitAnnsClause(ctx *parsergen.AnnsClauseContext) interface{} {
|
||||
text := GetOriginalText(ctx)
|
||||
|
||||
var column *planner.NodeFullColumnName
|
||||
var vectors []*planner.NodeVector
|
||||
|
||||
columnCtx := ctx.FullColumnName()
|
||||
if columnCtx != nil {
|
||||
r := columnCtx.Accept(v)
|
||||
if err := GetError(r); err != nil {
|
||||
return err
|
||||
}
|
||||
column = planner.NewNodeFullColumnName(GetOriginalText(columnCtx), r.(string))
|
||||
}
|
||||
|
||||
vectorsCtx := ctx.AnnsVectors()
|
||||
if vectorsCtx != nil {
|
||||
r := vectorsCtx.Accept(v)
|
||||
if err := GetError(r); err != nil {
|
||||
return err
|
||||
}
|
||||
vectors = r.([]*planner.NodeVector)
|
||||
}
|
||||
|
||||
var opts []planner.NodeANNSClauseOption
|
||||
|
||||
var paramsCtx parsergen.IAnnsParamsClauseContext
|
||||
|
||||
// Don't use ctx.AnnsParamsClause() directly, especially when the nq is too large.
|
||||
// In fact, ctx.AnnsParamsClause() will iterate all children from the beginning index, which
|
||||
// is not very efficient.
|
||||
children := ctx.GetChildren()
|
||||
lenOfChildren := len(children)
|
||||
for i := lenOfChildren - 1; i >= 0; i-- {
|
||||
if childCtx, ok := children[i].(parsergen.IAnnsParamsClauseContext); ok {
|
||||
paramsCtx = childCtx
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if paramsCtx != nil {
|
||||
r := paramsCtx.Accept(v)
|
||||
if err := GetError(r); err != nil {
|
||||
return err
|
||||
}
|
||||
if n := GetKVPairs(r); n != nil {
|
||||
opts = append(opts, planner.NodeANNSClauseWithParams(n))
|
||||
}
|
||||
}
|
||||
|
||||
return planner.NewNodeANNSClause(text, column, vectors, opts...)
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitAnnsVectors(ctx *parsergen.AnnsVectorsContext) interface{} {
|
||||
var vectors []*planner.NodeVector
|
||||
|
||||
allVectorsCtx := ctx.AllAnnsVector()
|
||||
|
||||
for _, vectorCtx := range allVectorsCtx {
|
||||
r := vectorCtx.Accept(v)
|
||||
if err := GetError(r); err != nil {
|
||||
return err
|
||||
}
|
||||
if n := GetVector(r); n != nil {
|
||||
vectors = append(vectors, n)
|
||||
}
|
||||
}
|
||||
|
||||
return vectors
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitAnnsVector(ctx *parsergen.AnnsVectorContext) interface{} {
|
||||
if ctx.BIT_STRING() != nil {
|
||||
return fmt.Errorf("binary vector is not supported")
|
||||
}
|
||||
|
||||
var floatArray []float32
|
||||
|
||||
floatArrayCtx := ctx.FloatArray()
|
||||
if floatArrayCtx != nil {
|
||||
r := floatArrayCtx.Accept(v)
|
||||
if err := GetError(r); err != nil {
|
||||
return err
|
||||
}
|
||||
floatArray = r.([]float32)
|
||||
}
|
||||
|
||||
return planner.NewNodeVector(planner.WithFloatVector(planner.NewNodeFloatVector(floatArray)))
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitFloatArray(ctx *parsergen.FloatArrayContext) interface{} {
|
||||
var floatArray []float32
|
||||
|
||||
allDecimalCtx := ctx.AllDecimalLiteral()
|
||||
for _, childCtx := range allDecimalCtx {
|
||||
r := childCtx.Accept(v)
|
||||
switch rWithType := r.(type) {
|
||||
case int64:
|
||||
floatArray = append(floatArray, float32(rWithType))
|
||||
case float32:
|
||||
floatArray = append(floatArray, rWithType)
|
||||
case float64:
|
||||
floatArray = append(floatArray, float32(rWithType))
|
||||
case error:
|
||||
return rWithType
|
||||
default:
|
||||
// TODO
|
||||
return fmt.Errorf("failed to parse float vector: %s", GetOriginalText(childCtx))
|
||||
}
|
||||
}
|
||||
|
||||
return floatArray
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitAnnsParamsClause(ctx *parsergen.AnnsParamsClauseContext) interface{} {
|
||||
return ctx.KvPairs().Accept(v)
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitKvPairs(ctx *parsergen.KvPairsContext) interface{} {
|
||||
allKvPairs := ctx.AllKvPair()
|
||||
lenOfPairs := len(allKvPairs)
|
||||
|
||||
if lenOfPairs == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pairs := planner.NewNodeKVPairs()
|
||||
|
||||
for _, child := range allKvPairs {
|
||||
childCtx := child.(*parsergen.KvPairContext)
|
||||
key := childCtx.ID().GetText()
|
||||
value := childCtx.Value().Accept(v)
|
||||
switch tv := value.(type) {
|
||||
case string:
|
||||
pairs.Insert(key, tv)
|
||||
case int:
|
||||
pairs.Insert(key, strconv.Itoa(tv))
|
||||
case int32:
|
||||
pairs.Insert(key, strconv.Itoa(int(tv)))
|
||||
case int64:
|
||||
pairs.Insert(key, strconv.Itoa(int(tv)))
|
||||
case float32:
|
||||
pairs.Insert(key, sqlutil.Float32ToString(tv))
|
||||
case float64:
|
||||
pairs.Insert(key, sqlutil.Float64ToString(tv))
|
||||
default:
|
||||
return fmt.Errorf("invalid type: %s", GetOriginalText(childCtx))
|
||||
}
|
||||
}
|
||||
|
||||
return pairs
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitValue(ctx *parsergen.ValueContext) interface{} {
|
||||
if idCtx := ctx.ID(); idCtx != nil {
|
||||
return idCtx.GetText()
|
||||
}
|
||||
|
||||
return ctx.Constant().Accept(v)
|
||||
}
|
||||
|
||||
func (v *AstBuilder) VisitTableSources(ctx *parsergen.TableSourcesContext) interface{} {
|
||||
// Should not be visited.
|
||||
return nil
|
||||
|
|
|
@ -77,6 +77,14 @@ func GetFromClause(obj interface{}) *planner.NodeFromClause {
|
|||
return n
|
||||
}
|
||||
|
||||
func GetANNSClause(obj interface{}) *planner.NodeANNSClause {
|
||||
n, ok := obj.(*planner.NodeANNSClause)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func GetLimitClause(obj interface{}) *planner.NodeLimitClause {
|
||||
n, ok := obj.(*planner.NodeLimitClause)
|
||||
if !ok {
|
||||
|
@ -140,3 +148,19 @@ func GetExpressions(obj interface{}) *planner.NodeExpressions {
|
|||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func GetKVPairs(obj interface{}) *planner.NodeKVPairs {
|
||||
n, ok := obj.(*planner.NodeKVPairs)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func GetVector(obj interface{}) *planner.NodeVector {
|
||||
n, ok := obj.(*planner.NodeVector)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
|
|
@ -4,9 +4,10 @@ import (
|
|||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/parser"
|
||||
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
func Test_antlrParser_Parse(t *testing.T) {
|
||||
|
@ -48,3 +49,35 @@ func Test_antlrParser_Parse(t *testing.T) {
|
|||
debug(t, sql)
|
||||
}
|
||||
}
|
||||
|
||||
type ANNSSuite struct {
|
||||
suite.Suite
|
||||
|
||||
p parser.Parser
|
||||
}
|
||||
|
||||
func (suite *ANNSSuite) SetupTest() {
|
||||
suite.p = NewAntlrParser()
|
||||
}
|
||||
|
||||
func (suite *ANNSSuite) TearDownTest() {}
|
||||
|
||||
func TestANNSSuite(t *testing.T) {
|
||||
suite.Run(t, new(ANNSSuite))
|
||||
}
|
||||
|
||||
func (suite *ANNSSuite) TestFloatVector() {
|
||||
sql := `
|
||||
select query_number, id, distance
|
||||
from t
|
||||
where id >= 1000 and id <= 10000
|
||||
anns by feature -> ([0.23, 0.21], [0.24, 0.26]) PARAMS = (nprobe=1, ef=5)
|
||||
limit 100
|
||||
`
|
||||
|
||||
plan, warns, err := suite.p.Parse(sql)
|
||||
suite.NoError(err)
|
||||
suite.Nil(warns)
|
||||
|
||||
planner.NewTreeUtils().PrettyPrintHrn(GetSqlStatements(plan.Node))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package planner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/moznion/go-optional"
|
||||
)
|
||||
|
||||
type NodeANNSClause struct {
|
||||
baseNode
|
||||
Column *NodeFullColumnName
|
||||
Vectors []*NodeVector
|
||||
Params optional.Option[*NodeKVPairs]
|
||||
}
|
||||
|
||||
func (n *NodeANNSClause) String() string {
|
||||
s := fmt.Sprintf("NodeANNSClause, Column: %s, Nq: %d", n.Column.String(), len(n.Vectors))
|
||||
if n.Params.IsSome() {
|
||||
s += fmt.Sprintf(", Params: %s", n.Params.Unwrap().String())
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (n *NodeANNSClause) GetChildren() []Node {
|
||||
// return []Node{n.Column}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NodeANNSClause) Accept(v Visitor) interface{} {
|
||||
return v.VisitANNSClause(n)
|
||||
}
|
||||
|
||||
type NodeANNSClauseOption func(*NodeANNSClause)
|
||||
|
||||
func NodeANNSClauseWithParams(p *NodeKVPairs) NodeANNSClauseOption {
|
||||
return func(n *NodeANNSClause) {
|
||||
n.Params = optional.Some(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NodeANNSClause) apply(opts ...NodeANNSClauseOption) {
|
||||
for _, opt := range opts {
|
||||
opt(n)
|
||||
}
|
||||
}
|
||||
|
||||
func NewNodeANNSClause(text string, column *NodeFullColumnName, vectors []*NodeVector, opts ...NodeANNSClauseOption) *NodeANNSClause {
|
||||
n := &NodeANNSClause{
|
||||
baseNode: newBaseNode(text),
|
||||
Column: column,
|
||||
Vectors: vectors,
|
||||
}
|
||||
n.apply(opts...)
|
||||
return n
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package planner
|
||||
|
||||
type NodeBinaryVector struct {
|
||||
baseNode
|
||||
//TODO
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package planner
|
||||
|
||||
type NodeFloatVector struct {
|
||||
Array []float32
|
||||
}
|
||||
|
||||
func NewNodeFloatVector(arr []float32) *NodeFloatVector {
|
||||
return &NodeFloatVector{
|
||||
Array: arr,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package planner
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type NodeKVPairs struct {
|
||||
KVs map[string]string
|
||||
}
|
||||
|
||||
func (n *NodeKVPairs) Insert(key, value string) {
|
||||
n.KVs[key] = value
|
||||
}
|
||||
|
||||
func (n *NodeKVPairs) String() string {
|
||||
// How could `Marshal` return error here?
|
||||
bs, _ := json.Marshal(n.KVs)
|
||||
return string(bs)
|
||||
}
|
||||
|
||||
func NewNodeKVPairs() *NodeKVPairs {
|
||||
return &NodeKVPairs{
|
||||
KVs: make(map[string]string),
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ type NodeQuerySpecification struct {
|
|||
SelectSpecs []*NodeSelectSpec
|
||||
SelectElements []*NodeSelectElement
|
||||
From optional.Option[*NodeFromClause]
|
||||
Anns optional.Option[*NodeANNSClause]
|
||||
Limit optional.Option[*NodeLimitClause]
|
||||
}
|
||||
|
||||
|
@ -16,18 +17,27 @@ func (n *NodeQuerySpecification) String() string {
|
|||
|
||||
func (n *NodeQuerySpecification) GetChildren() []Node {
|
||||
children := make([]Node, 0, len(n.SelectSpecs)+len(n.SelectElements)+2)
|
||||
|
||||
for _, child := range n.SelectSpecs {
|
||||
children = append(children, child)
|
||||
}
|
||||
|
||||
for _, child := range n.SelectElements {
|
||||
children = append(children, child)
|
||||
}
|
||||
|
||||
if n.From.IsSome() {
|
||||
children = append(children, n.From.Unwrap())
|
||||
}
|
||||
|
||||
if n.Anns.IsSome() {
|
||||
children = append(children, n.Anns.Unwrap())
|
||||
}
|
||||
|
||||
if n.Limit.IsSome() {
|
||||
children = append(children, n.Limit.Unwrap())
|
||||
}
|
||||
|
||||
return children
|
||||
}
|
||||
|
||||
|
@ -49,6 +59,12 @@ func WithFrom(from *NodeFromClause) NodeQuerySpecificationOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithANNS(anns *NodeANNSClause) NodeQuerySpecificationOption {
|
||||
return func(n *NodeQuerySpecification) {
|
||||
n.Anns = optional.Some(anns)
|
||||
}
|
||||
}
|
||||
|
||||
func WithLimit(Limit *NodeLimitClause) NodeQuerySpecificationOption {
|
||||
return func(n *NodeQuerySpecification) {
|
||||
n.Limit = optional.Some(Limit)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package planner
|
||||
|
||||
import "github.com/moznion/go-optional"
|
||||
|
||||
type NodeVector struct {
|
||||
FloatVector optional.Option[*NodeFloatVector]
|
||||
}
|
||||
|
||||
type NodeVectorOption func(*NodeVector)
|
||||
|
||||
func (n *NodeVector) apply(opts ...NodeVectorOption) {
|
||||
for _, opt := range opts {
|
||||
opt(n)
|
||||
}
|
||||
}
|
||||
|
||||
func WithFloatVector(v *NodeFloatVector) NodeVectorOption {
|
||||
return func(n *NodeVector) {
|
||||
n.FloatVector = optional.Some(v)
|
||||
}
|
||||
}
|
||||
|
||||
func NewNodeVector(opts ...NodeVectorOption) *NodeVector {
|
||||
n := &NodeVector{}
|
||||
n.apply(opts...)
|
||||
return n
|
||||
}
|
|
@ -31,4 +31,14 @@ type Visitor interface {
|
|||
VisitUnaryExpressionAtom(n *NodeUnaryExpressionAtom) interface{}
|
||||
VisitNestedExpressionAtom(n *NodeNestedExpressionAtom) interface{}
|
||||
VisitConstant(n *NodeConstant) interface{}
|
||||
|
||||
/*
|
||||
// In fact, these structs are not enough to be a node.
|
||||
// They themselves alone don't make any sense. Just regard them as parameters.
|
||||
VisitFloatVector(n *NodeFloatVector) interface{}
|
||||
VisitVector(n *NodeVector) interface{}
|
||||
VisitKVPairs(n *NodeKVPairs) interface{}
|
||||
*/
|
||||
|
||||
VisitANNSClause(*NodeANNSClause) interface{}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package planner
|
|||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/sqlutil"
|
||||
)
|
||||
|
||||
// TODO: remove this after execution engine is ready.
|
||||
|
@ -120,7 +122,7 @@ func (v *exprTextRestorer) VisitConstant(n *NodeConstant) interface{} {
|
|||
return strconv.FormatBool(n.BooleanLiteral.Unwrap())
|
||||
}
|
||||
if n.RealLiteral.IsSome() {
|
||||
return strconv.FormatFloat(n.RealLiteral.Unwrap(), 'g', 10, 14)
|
||||
return sqlutil.Float64ToString(n.RealLiteral.Unwrap())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package planner
|
||||
|
||||
import "strconv"
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/sqlutil"
|
||||
)
|
||||
|
||||
type jsonVisitor struct {
|
||||
}
|
||||
|
@ -362,12 +366,22 @@ func (v jsonVisitor) VisitConstant(n *NodeConstant) interface{} {
|
|||
}
|
||||
|
||||
if n.RealLiteral.IsSome() {
|
||||
j["real_literal"] = strconv.FormatFloat(n.RealLiteral.Unwrap(), 'f', -1, 64)
|
||||
j["real_literal"] = sqlutil.Float64ToString(n.RealLiteral.Unwrap())
|
||||
}
|
||||
|
||||
return j
|
||||
}
|
||||
|
||||
func (v jsonVisitor) VisitANNSClause(n *NodeANNSClause) interface{} {
|
||||
// leaf node.
|
||||
|
||||
j := map[string]interface{}{}
|
||||
|
||||
j["anns"] = n.String()
|
||||
|
||||
return j
|
||||
}
|
||||
|
||||
func NewJSONVisitor() Visitor {
|
||||
return &jsonVisitor{}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
package sqlutil
|
||||
|
||||
import "strconv"
|
||||
|
||||
func Float32ToString(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', -1, 32)
|
||||
}
|
||||
|
||||
func Float64ToString(f float64) string {
|
||||
return strconv.FormatFloat(f, 'f', -1, 64)
|
||||
}
|
Loading…
Reference in New Issue