// milvus/internal/mysqld/executor/executor_test.go

package executor
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/mysqld/parser/antlrparser"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/mysqld/planner"
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
)
// Test_defaultExecutor_Run covers the executor entry point: a plan whose root
// is not a statement list, a plan holding more than one statement, and a
// single SELECT that is answered through the mocked proxy's Query RPC.
func Test_defaultExecutor_Run(t *testing.T) {
	t.Run("not sql statements", func(t *testing.T) {
		e := NewDefaultExecutor(nil)
		// The plan root is a string constant, not a list of SQL statements.
		plan := &planner.PhysicalPlan{
			Node: planner.NewNodeConstant("20230306", planner.WithStringLiteral("20230306")),
		}
		_, err := e.Run(context.TODO(), plan)
		assert.Error(t, err)
	})

	t.Run("multiple statements", func(t *testing.T) {
		e := NewDefaultExecutor(nil)
		stmts := []*planner.NodeSqlStatement{
			planner.NewNodeSqlStatement("sql1"),
			planner.NewNodeSqlStatement("sql2"),
		}
		plan := &planner.PhysicalPlan{
			Node: planner.NewNodeSqlStatements("sql1; sql2", stmts),
		}
		_, err := e.Run(context.TODO(), plan)
		assert.Error(t, err)
	})

	t.Run("normal case", func(t *testing.T) {
		s := mocks.NewProxyComponent(t)
		res := &milvuspb.QueryResults{
			Status: &commonpb.Status{},
			FieldsData: []*schemapb.FieldData{
				{
					Type:      schemapb.DataType_Int64,
					FieldName: "field",
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_LongData{
								LongData: &schemapb.LongArray{
									Data: []int64{1, 2, 3, 4},
								},
							},
						},
					},
				},
			},
			CollectionName: "test",
		}
		s.On("Query",
			mock.Anything, // context.Context
			mock.Anything, // *milvuspb.QueryRequest
		).Return(res, nil)
		e := NewDefaultExecutor(s)

		// Plan for: SELECT field FROM test WHERE field = field
		lhs := planner.NewNodePredicate("", planner.WithNodeExpressionAtomPredicate(
			planner.NewNodeExpressionAtomPredicate("",
				planner.NewNodeExpressionAtom("", planner.ExpressionAtomWithFullColumnName(
					planner.NewNodeFullColumnName("", "field"))))))
		rhs := planner.NewNodePredicate("", planner.WithNodeExpressionAtomPredicate(
			planner.NewNodeExpressionAtomPredicate("",
				planner.NewNodeExpressionAtom("", planner.ExpressionAtomWithFullColumnName(
					planner.NewNodeFullColumnName("", "field"))))))
		where := planner.NewNodeExpression("", planner.WithPredicate(
			planner.NewNodePredicate("", planner.WithNodeBinaryComparisonPredicate(
				planner.NewNodeBinaryComparisonPredicate("", lhs, rhs, planner.ComparisonOperatorEqual)))))
		query := planner.NewNodeQuerySpecification("",
			nil,
			[]*planner.NodeSelectElement{
				planner.NewNodeSelectElement("", planner.WithFullColumnName(
					planner.NewNodeFullColumnName("", "field"))),
			},
			planner.WithFrom(planner.NewNodeFromClause("",
				[]*planner.NodeTableSource{
					planner.NewNodeTableSource("", planner.WithTableName("test")),
				},
				planner.WithWhere(where))))
		stmts := []*planner.NodeSqlStatement{
			planner.NewNodeSqlStatement("", planner.WithDmlStatement(
				planner.NewNodeDmlStatement("", planner.WithSelectStatement(
					planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
						planner.NewNodeSimpleSelect("", planner.WithQuery(query)))))))),
		}
		plan := &planner.PhysicalPlan{
			Node: planner.NewNodeSqlStatements("", stmts),
		}

		sqlRes, err := e.Run(context.TODO(), plan)
		// assert.* only records failures and keeps going; stop here rather than
		// dereference a nil or short result and turn a failure into a panic.
		if !assert.NoError(t, err) || !assert.NotNil(t, sqlRes) {
			return
		}
		fieldsOK := assert.Len(t, sqlRes.Fields, 1)
		rowsOK := assert.Len(t, sqlRes.Rows, 4)
		if !fieldsOK || !rowsOK {
			return
		}
		assert.Equal(t, "field", sqlRes.Fields[0].Name)
		assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
		if assert.Len(t, sqlRes.Rows[0], 1) {
			assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
		}
	})
}
// Test_defaultExecutor_dispatch checks that dispatch returns an error for a
// SQL statement node built without any options (no DML statement attached).
func Test_defaultExecutor_dispatch(t *testing.T) {
	t.Run("not dml", func(t *testing.T) {
		exec := NewDefaultExecutor(nil).(*defaultExecutor)
		_, err := exec.dispatch(context.TODO(), planner.NewNodeSqlStatement(""))
		assert.Error(t, err)
	})
}
// Test_defaultExecutor_dispatchDmlStatement checks that dispatchDmlStatement
// returns an error for a DML statement node built without any options (no
// SELECT statement attached).
func Test_defaultExecutor_dispatchDmlStatement(t *testing.T) {
	t.Run("not select", func(t *testing.T) {
		exec := NewDefaultExecutor(nil).(*defaultExecutor)
		_, err := exec.dispatchDmlStatement(context.TODO(), planner.NewNodeDmlStatement(""))
		assert.Error(t, err)
	})
}
func Test_defaultExecutor_execSelect(t *testing.T) {
t.Run("not simple select", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("")
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("lock clause, not supported", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithLockClause(
planner.NewNodeLockClause("", planner.LockClauseOptionForUpdate)))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("not query", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("")))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("pagination, not supported", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
nil,
planner.WithLimit(planner.NewNodeLimitClause("", 1, 2)))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("select specs, not supported", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
[]*planner.NodeSelectSpec{
planner.NewNodeSelectSpec(""),
},
nil)))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("no from clause", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
nil)))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("no table source", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
nil,
planner.WithFrom(planner.NewNodeFromClause("", nil)))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("multiple table source", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
nil,
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource(""),
planner.NewNodeTableSource(""),
})))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("target entry as alias, not supported", func(t *testing.T) {
e := NewDefaultExecutor(nil).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFullColumnName(
planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
})))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("failed to execute count", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("GetCollectionStatistics",
mock.Anything, // context.Context
mock.Anything, // *milvuspb.GetCollectionStatisticsRequest
).Return(nil, errors.New("error mock GetCollectionStatistics"))
e := NewDefaultExecutor(s).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFunctionCall(
planner.NewNodeFunctionCall("", planner.WithAgg(
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
planner.NewNodeCount(""))))))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
})))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("count without filter", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("GetCollectionStatistics",
mock.Anything, // context.Context
mock.Anything, // *milvuspb.GetCollectionStatisticsRequest
).Return(&milvuspb.GetCollectionStatisticsResponse{
Status: &commonpb.Status{},
Stats: []*commonpb.KeyValuePair{{Value: "2"}},
}, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFunctionCall(
planner.NewNodeFunctionCall("", planner.WithAgg(
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
planner.NewNodeCount(""))))))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
})))))))
res, err := e.execSelect(context.TODO(), n)
assert.NoError(t, err)
assert.Equal(t, 1, len(res.Fields))
assert.Equal(t, "count(*)", res.Fields[0].Name)
assert.Equal(t, querypb.Type_INT64, res.Fields[0].Type)
assert.Equal(t, 1, len(res.Rows))
assert.Equal(t, 1, len(res.Rows[0]))
assert.Equal(t, querypb.Type_INT64, res.Rows[0][0].Type())
})
t.Run("count with filter", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // *milvuspb.QueryRequest
).Return(&milvuspb.QueryResults{
Status: &commonpb.Status{},
FieldsData: []*schemapb.FieldData{
{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4},
},
},
},
},
},
},
}, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFunctionCall(
planner.NewNodeFunctionCall("", planner.WithAgg(
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
planner.NewNodeCount(""))))))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
},
planner.WithWhere(GenNodeExpression("field", 1, planner.ComparisonOperatorGreaterEqual)))))))))
res, err := e.execSelect(context.TODO(), n)
assert.NoError(t, err)
assert.Equal(t, 1, len(res.Fields))
assert.Equal(t, "count(*)", res.Fields[0].Name)
assert.Equal(t, querypb.Type_INT64, res.Fields[0].Type)
assert.Equal(t, 1, len(res.Rows))
assert.Equal(t, 1, len(res.Rows[0]))
assert.Equal(t, querypb.Type_INT64, res.Rows[0][0].Type())
})
t.Run("query without filter", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
e := NewDefaultExecutor(s).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFullColumnName(
planner.NewNodeFullColumnName("", "field"))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
})))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("failed to query with filter", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // *milvuspb.QueryRequest
).Return(nil, errors.New("error mock Query"))
e := NewDefaultExecutor(s).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFullColumnName(
planner.NewNodeFullColumnName("", "field"))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
},
planner.WithWhere(GenNodeExpression("field", 100, planner.ComparisonOperatorEqual)))))))))
_, err := e.execSelect(context.TODO(), n)
assert.Error(t, err)
})
t.Run("query with filter", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
res := &milvuspb.QueryResults{
Status: &commonpb.Status{},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "field",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4},
},
},
},
},
},
},
CollectionName: "test",
}
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // *milvuspb.QueryRequest
).Return(res, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
n := planner.NewNodeSelectStatement("", planner.WithSimpleSelect(
planner.NewNodeSimpleSelect("", planner.WithQuery(
planner.NewNodeQuerySpecification("",
nil,
[]*planner.NodeSelectElement{
planner.NewNodeSelectElement("", planner.WithFullColumnName(
planner.NewNodeFullColumnName("", "field"))),
},
planner.WithFrom(planner.NewNodeFromClause("",
[]*planner.NodeTableSource{
planner.NewNodeTableSource("", planner.WithTableName("test")),
},
planner.WithWhere(GenNodeExpression("field", 100, planner.ComparisonOperatorEqual)))))))))
sqlRes, err := e.execSelect(context.TODO(), n)
assert.NoError(t, err)
assert.Equal(t, 1, len(sqlRes.Fields))
assert.Equal(t, 4, len(sqlRes.Rows))
assert.Equal(t, "field", sqlRes.Fields[0].Name)
assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
assert.Equal(t, 1, len(sqlRes.Rows[0]))
assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
})
}
// Test_defaultExecutor_execCountWithFilter checks execCountWithFilter against
// a failing Query RPC and a successful one. The success case expects a single
// `count(*)` INT64 column holding one row.
func Test_defaultExecutor_execCountWithFilter(t *testing.T) {
	t.Run("failed to query", func(t *testing.T) {
		s := mocks.NewProxyComponent(t)
		s.On("Query",
			mock.Anything, // context.Context
			mock.Anything, // milvuspb.QueryRequest
		).Return(nil, errors.New("error mock Query"))
		e := NewDefaultExecutor(s).(*defaultExecutor)
		_, err := e.execCountWithFilter(context.TODO(), "t", "a > 2")
		assert.Error(t, err)
	})

	t.Run("normal case", func(t *testing.T) {
		s := mocks.NewProxyComponent(t)
		res := &milvuspb.QueryResults{
			Status: &commonpb.Status{},
			FieldsData: []*schemapb.FieldData{
				{
					Type:      schemapb.DataType_Int64,
					FieldName: "field",
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_LongData{
								LongData: &schemapb.LongArray{
									Data: []int64{1, 2, 3, 4},
								},
							},
						},
					},
				},
			},
			CollectionName: "test",
		}
		s.On("Query",
			mock.Anything, // context.Context
			mock.Anything, // *milvuspb.QueryRequest
		).Return(res, nil)
		e := NewDefaultExecutor(s).(*defaultExecutor)
		sqlRes, err := e.execCountWithFilter(context.TODO(), "t", "a > 2")
		// assert.* does not stop the test; bail out before dereferencing a nil
		// result, and only index into slices whose length was confirmed.
		if !assert.NoError(t, err) || !assert.NotNil(t, sqlRes) {
			return
		}
		if assert.Len(t, sqlRes.Fields, 1) {
			assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
			assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
		}
		if assert.Len(t, sqlRes.Rows, 1) && assert.Len(t, sqlRes.Rows[0], 1) {
			assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
		}
	})
}
func Test_defaultExecutor_execQuery(t *testing.T) {
t.Run("rpc failure", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(nil, errors.New("error mock Query"))
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"})
assert.Error(t, err)
})
t.Run("not success", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(&milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "error mock reason",
},
}, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"})
assert.Error(t, err)
})
t.Run("success", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(&milvuspb.QueryResults{
Status: &commonpb.Status{},
}, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"})
assert.NoError(t, err)
})
}
// Test_defaultExecutor_execANNS parses an `anns by` SELECT with the ANTLR
// parser, then runs execANNS against a mocked Search RPC. The mocked response
// carries two queries' worth of pk/random hits.
func Test_defaultExecutor_execANNS(t *testing.T) {
	f1 := &schemapb.FieldData{
		Type:      schemapb.DataType_Int64,
		FieldName: "pk",
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{
						Data: []int64{6, 5, 4, 3, 2, 1},
					},
				},
			},
		},
	}
	f2 := &schemapb.FieldData{
		Type:      schemapb.DataType_Float,
		FieldName: "random",
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_FloatData{
					FloatData: &schemapb.FloatArray{
						Data: []float32{6.6, 5.5, 4.4, 3.3, 2.2, 1.1},
					},
				},
			},
		},
	}
	// NOTE(review): 6 entries are supplied per field and in Scores
	// (NumQueries 2 x TopK 3), but Topks only accounts for 2+2 = 4 hits.
	// Confirm whether Topks was meant to be {3, 3}.
	res := &milvuspb.SearchResults{
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
		Results: &schemapb.SearchResultData{
			NumQueries: 2,
			TopK:       3,
			FieldsData: []*schemapb.FieldData{f1, f2},
			Scores:     []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6},
			Topks:      []int64{2, 2},
		},
		CollectionName: "hello_milvus",
	}
	s := mocks.NewProxyComponent(t)
	s.On("Search",
		mock.Anything, // context.Context
		mock.Anything, // *milvuspb.SearchRequest
	).Return(res, nil)
	sql := `
select
$query_number, pk, random, $distance
from hello_milvus
where random > 0.5
anns by embeddings -> ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], [0.8, 0.7, 0.6, 0.6, 0.4, 0.3, 0.2, 0.1])
params = (metric_type=L2, nprobe=10)
limit 3
`
	plan, warns, err := antlrparser.NewAntlrParser().Parse(sql)
	// assert.* does not stop the test; don't walk the plan if parsing failed.
	if !assert.NoError(t, err) {
		return
	}
	assert.Nil(t, warns)
	stmts := antlrparser.GetSqlStatements(plan.Node)
	if !assert.NotNil(t, stmts) || !assert.NotEmpty(t, stmts.Statements) {
		return
	}
	query := stmts.Statements[0].DmlStatement.Unwrap().SelectStatement.Unwrap().SimpleSelect.Unwrap().Query.Unwrap()

	e := NewDefaultExecutor(s).(*defaultExecutor)
	_, err = e.execANNS(context.TODO(), query,
		[]string{"$query_number", "pk", "random", "$distance"})
	assert.NoError(t, err)
}