Support to parse ANNS plan node (#23463)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/23506/head
Jiquan Long 2023-04-18 10:06:32 +08:00 committed by GitHub
parent a455595c9b
commit c013492762
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 411 additions and 5 deletions

View File

@ -4,6 +4,8 @@ import (
"fmt"
"strconv"
"github.com/milvus-io/milvus/internal/mysqld/sqlutil"
"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
@ -172,6 +174,17 @@ func (v *AstBuilder) VisitQuerySpecification(ctx *parsergen.QuerySpecificationCo
}
}
annsCtx := ctx.AnnsClause()
if annsCtx != nil {
r := annsCtx.Accept(v)
if err := GetError(r); err != nil {
return err
}
if n := GetANNSClause(r); n != nil {
opts = append(opts, planner.WithANNS(n))
}
}
limitCtx := ctx.LimitClause()
if limitCtx != nil {
r := limitCtx.Accept(v)
@ -290,6 +303,167 @@ func (v *AstBuilder) VisitFromClause(ctx *parsergen.FromClauseContext) interface
return planner.NewNodeFromClause(text, tableSources, opts...)
}
func (v *AstBuilder) VisitAnnsClause(ctx *parsergen.AnnsClauseContext) interface{} {
text := GetOriginalText(ctx)
var column *planner.NodeFullColumnName
var vectors []*planner.NodeVector
columnCtx := ctx.FullColumnName()
if columnCtx != nil {
r := columnCtx.Accept(v)
if err := GetError(r); err != nil {
return err
}
column = planner.NewNodeFullColumnName(GetOriginalText(columnCtx), r.(string))
}
vectorsCtx := ctx.AnnsVectors()
if vectorsCtx != nil {
r := vectorsCtx.Accept(v)
if err := GetError(r); err != nil {
return err
}
vectors = r.([]*planner.NodeVector)
}
var opts []planner.NodeANNSClauseOption
var paramsCtx parsergen.IAnnsParamsClauseContext
// Don't use ctx.AnnsParamsClause() directly, especially when the nq is too large.
// In fact, ctx.AnnsParamsClause() will iterate all children from the beginning index, which
// is not very efficient.
children := ctx.GetChildren()
lenOfChildren := len(children)
for i := lenOfChildren - 1; i >= 0; i-- {
if childCtx, ok := children[i].(parsergen.IAnnsParamsClauseContext); ok {
paramsCtx = childCtx
break
}
}
if paramsCtx != nil {
r := paramsCtx.Accept(v)
if err := GetError(r); err != nil {
return err
}
if n := GetKVPairs(r); n != nil {
opts = append(opts, planner.NodeANNSClauseWithParams(n))
}
}
return planner.NewNodeANNSClause(text, column, vectors, opts...)
}
func (v *AstBuilder) VisitAnnsVectors(ctx *parsergen.AnnsVectorsContext) interface{} {
var vectors []*planner.NodeVector
allVectorsCtx := ctx.AllAnnsVector()
for _, vectorCtx := range allVectorsCtx {
r := vectorCtx.Accept(v)
if err := GetError(r); err != nil {
return err
}
if n := GetVector(r); n != nil {
vectors = append(vectors, n)
}
}
return vectors
}
func (v *AstBuilder) VisitAnnsVector(ctx *parsergen.AnnsVectorContext) interface{} {
if ctx.BIT_STRING() != nil {
return fmt.Errorf("binary vector is not supported")
}
var floatArray []float32
floatArrayCtx := ctx.FloatArray()
if floatArrayCtx != nil {
r := floatArrayCtx.Accept(v)
if err := GetError(r); err != nil {
return err
}
floatArray = r.([]float32)
}
return planner.NewNodeVector(planner.WithFloatVector(planner.NewNodeFloatVector(floatArray)))
}
func (v *AstBuilder) VisitFloatArray(ctx *parsergen.FloatArrayContext) interface{} {
var floatArray []float32
allDecimalCtx := ctx.AllDecimalLiteral()
for _, childCtx := range allDecimalCtx {
r := childCtx.Accept(v)
switch rWithType := r.(type) {
case int64:
floatArray = append(floatArray, float32(rWithType))
case float32:
floatArray = append(floatArray, rWithType)
case float64:
floatArray = append(floatArray, float32(rWithType))
case error:
return rWithType
default:
// TODO
return fmt.Errorf("failed to parse float vector: %s", GetOriginalText(childCtx))
}
}
return floatArray
}
func (v *AstBuilder) VisitAnnsParamsClause(ctx *parsergen.AnnsParamsClauseContext) interface{} {
return ctx.KvPairs().Accept(v)
}
func (v *AstBuilder) VisitKvPairs(ctx *parsergen.KvPairsContext) interface{} {
allKvPairs := ctx.AllKvPair()
lenOfPairs := len(allKvPairs)
if lenOfPairs == 0 {
return nil
}
pairs := planner.NewNodeKVPairs()
for _, child := range allKvPairs {
childCtx := child.(*parsergen.KvPairContext)
key := childCtx.ID().GetText()
value := childCtx.Value().Accept(v)
switch tv := value.(type) {
case string:
pairs.Insert(key, tv)
case int:
pairs.Insert(key, strconv.Itoa(tv))
case int32:
pairs.Insert(key, strconv.Itoa(int(tv)))
case int64:
pairs.Insert(key, strconv.Itoa(int(tv)))
case float32:
pairs.Insert(key, sqlutil.Float32ToString(tv))
case float64:
pairs.Insert(key, sqlutil.Float64ToString(tv))
default:
return fmt.Errorf("invalid type: %s", GetOriginalText(childCtx))
}
}
return pairs
}
func (v *AstBuilder) VisitValue(ctx *parsergen.ValueContext) interface{} {
if idCtx := ctx.ID(); idCtx != nil {
return idCtx.GetText()
}
return ctx.Constant().Accept(v)
}
func (v *AstBuilder) VisitTableSources(ctx *parsergen.TableSourcesContext) interface{} {
// Should not be visited.
return nil

View File

@ -77,6 +77,14 @@ func GetFromClause(obj interface{}) *planner.NodeFromClause {
return n
}
func GetANNSClause(obj interface{}) *planner.NodeANNSClause {
n, ok := obj.(*planner.NodeANNSClause)
if !ok {
return nil
}
return n
}
func GetLimitClause(obj interface{}) *planner.NodeLimitClause {
n, ok := obj.(*planner.NodeLimitClause)
if !ok {
@ -140,3 +148,19 @@ func GetExpressions(obj interface{}) *planner.NodeExpressions {
}
return n
}
func GetKVPairs(obj interface{}) *planner.NodeKVPairs {
n, ok := obj.(*planner.NodeKVPairs)
if !ok {
return nil
}
return n
}
func GetVector(obj interface{}) *planner.NodeVector {
n, ok := obj.(*planner.NodeVector)
if !ok {
return nil
}
return n
}

View File

@ -4,9 +4,10 @@ import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mysqld/parser"
"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
func Test_antlrParser_Parse(t *testing.T) {
@ -48,3 +49,35 @@ func Test_antlrParser_Parse(t *testing.T) {
debug(t, sql)
}
}
type ANNSSuite struct {
suite.Suite
p parser.Parser
}
func (suite *ANNSSuite) SetupTest() {
suite.p = NewAntlrParser()
}
func (suite *ANNSSuite) TearDownTest() {}
func TestANNSSuite(t *testing.T) {
suite.Run(t, new(ANNSSuite))
}
func (suite *ANNSSuite) TestFloatVector() {
sql := `
select query_number, id, distance
from t
where id >= 1000 and id <= 10000
anns by feature -> ([0.23, 0.21], [0.24, 0.26]) PARAMS = (nprobe=1, ef=5)
limit 100
`
plan, warns, err := suite.p.Parse(sql)
suite.NoError(err)
suite.Nil(warns)
planner.NewTreeUtils().PrettyPrintHrn(GetSqlStatements(plan.Node))
}

View File

@ -0,0 +1,55 @@
package planner
import (
"fmt"
"github.com/moznion/go-optional"
)
type NodeANNSClause struct {
baseNode
Column *NodeFullColumnName
Vectors []*NodeVector
Params optional.Option[*NodeKVPairs]
}
func (n *NodeANNSClause) String() string {
s := fmt.Sprintf("NodeANNSClause, Column: %s, Nq: %d", n.Column.String(), len(n.Vectors))
if n.Params.IsSome() {
s += fmt.Sprintf(", Params: %s", n.Params.Unwrap().String())
}
return s
}
func (n *NodeANNSClause) GetChildren() []Node {
// return []Node{n.Column}
return nil
}
func (n *NodeANNSClause) Accept(v Visitor) interface{} {
return v.VisitANNSClause(n)
}
type NodeANNSClauseOption func(*NodeANNSClause)
func NodeANNSClauseWithParams(p *NodeKVPairs) NodeANNSClauseOption {
return func(n *NodeANNSClause) {
n.Params = optional.Some(p)
}
}
func (n *NodeANNSClause) apply(opts ...NodeANNSClauseOption) {
for _, opt := range opts {
opt(n)
}
}
func NewNodeANNSClause(text string, column *NodeFullColumnName, vectors []*NodeVector, opts ...NodeANNSClauseOption) *NodeANNSClause {
n := &NodeANNSClause{
baseNode: newBaseNode(text),
Column: column,
Vectors: vectors,
}
n.apply(opts...)
return n
}

View File

@ -0,0 +1,6 @@
package planner
type NodeBinaryVector struct {
baseNode
//TODO
}

View File

@ -0,0 +1,11 @@
package planner
type NodeFloatVector struct {
Array []float32
}
func NewNodeFloatVector(arr []float32) *NodeFloatVector {
return &NodeFloatVector{
Array: arr,
}
}

View File

@ -0,0 +1,23 @@
package planner
import "encoding/json"
type NodeKVPairs struct {
KVs map[string]string
}
func (n *NodeKVPairs) Insert(key, value string) {
n.KVs[key] = value
}
func (n *NodeKVPairs) String() string {
// How could `Marshal` return error here?
bs, _ := json.Marshal(n.KVs)
return string(bs)
}
func NewNodeKVPairs() *NodeKVPairs {
return &NodeKVPairs{
KVs: make(map[string]string),
}
}

View File

@ -7,6 +7,7 @@ type NodeQuerySpecification struct {
SelectSpecs []*NodeSelectSpec
SelectElements []*NodeSelectElement
From optional.Option[*NodeFromClause]
Anns optional.Option[*NodeANNSClause]
Limit optional.Option[*NodeLimitClause]
}
@ -16,18 +17,27 @@ func (n *NodeQuerySpecification) String() string {
func (n *NodeQuerySpecification) GetChildren() []Node {
children := make([]Node, 0, len(n.SelectSpecs)+len(n.SelectElements)+2)
for _, child := range n.SelectSpecs {
children = append(children, child)
}
for _, child := range n.SelectElements {
children = append(children, child)
}
if n.From.IsSome() {
children = append(children, n.From.Unwrap())
}
if n.Anns.IsSome() {
children = append(children, n.Anns.Unwrap())
}
if n.Limit.IsSome() {
children = append(children, n.Limit.Unwrap())
}
return children
}
@ -49,6 +59,12 @@ func WithFrom(from *NodeFromClause) NodeQuerySpecificationOption {
}
}
func WithANNS(anns *NodeANNSClause) NodeQuerySpecificationOption {
return func(n *NodeQuerySpecification) {
n.Anns = optional.Some(anns)
}
}
func WithLimit(Limit *NodeLimitClause) NodeQuerySpecificationOption {
return func(n *NodeQuerySpecification) {
n.Limit = optional.Some(Limit)

View File

@ -0,0 +1,27 @@
package planner
import "github.com/moznion/go-optional"
type NodeVector struct {
FloatVector optional.Option[*NodeFloatVector]
}
type NodeVectorOption func(*NodeVector)
func (n *NodeVector) apply(opts ...NodeVectorOption) {
for _, opt := range opts {
opt(n)
}
}
func WithFloatVector(v *NodeFloatVector) NodeVectorOption {
return func(n *NodeVector) {
n.FloatVector = optional.Some(v)
}
}
func NewNodeVector(opts ...NodeVectorOption) *NodeVector {
n := &NodeVector{}
n.apply(opts...)
return n
}

View File

@ -31,4 +31,14 @@ type Visitor interface {
VisitUnaryExpressionAtom(n *NodeUnaryExpressionAtom) interface{}
VisitNestedExpressionAtom(n *NodeNestedExpressionAtom) interface{}
VisitConstant(n *NodeConstant) interface{}
/*
// In fact, these structs are not enough to be a node.
// They themselves alone don't make any sense. Just regard them as parameters.
VisitFloatVector(n *NodeFloatVector) interface{}
VisitVector(n *NodeVector) interface{}
VisitKVPairs(n *NodeKVPairs) interface{}
*/
VisitANNSClause(*NodeANNSClause) interface{}
}

View File

@ -3,6 +3,8 @@ package planner
import (
"fmt"
"strconv"
"github.com/milvus-io/milvus/internal/mysqld/sqlutil"
)
// TODO: remove this after execution engine is ready.
@ -120,7 +122,7 @@ func (v *exprTextRestorer) VisitConstant(n *NodeConstant) interface{} {
return strconv.FormatBool(n.BooleanLiteral.Unwrap())
}
if n.RealLiteral.IsSome() {
return strconv.FormatFloat(n.RealLiteral.Unwrap(), 'g', 10, 14)
return sqlutil.Float64ToString(n.RealLiteral.Unwrap())
}
return ""
}

View File

@ -1,6 +1,10 @@
package planner
import "strconv"
import (
"strconv"
"github.com/milvus-io/milvus/internal/mysqld/sqlutil"
)
type jsonVisitor struct {
}
@ -362,12 +366,22 @@ func (v jsonVisitor) VisitConstant(n *NodeConstant) interface{} {
}
if n.RealLiteral.IsSome() {
j["real_literal"] = strconv.FormatFloat(n.RealLiteral.Unwrap(), 'f', -1, 64)
j["real_literal"] = sqlutil.Float64ToString(n.RealLiteral.Unwrap())
}
return j
}
func (v jsonVisitor) VisitANNSClause(n *NodeANNSClause) interface{} {
// leaf node.
j := map[string]interface{}{}
j["anns"] = n.String()
return j
}
func NewJSONVisitor() Visitor {
return &jsonVisitor{}
}

View File

@ -0,0 +1,11 @@
package sqlutil
import "strconv"
func Float32ToString(f float32) string {
return strconv.FormatFloat(float64(f), 'f', -1, 32)
}
func Float64ToString(f float64) string {
return strconv.FormatFloat(f, 'f', -1, 64)
}