mirror of https://github.com/milvus-io/milvus.git
216 lines
5.8 KiB
Go
216 lines
5.8 KiB
Go
package executor
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
|
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
|
|
)
|
|
|
|
func GenNodeExpression(field string, c int64, op planner.ComparisonOperator) *planner.NodeExpression {
|
|
n := planner.NewNodeExpression("", planner.WithPredicate(
|
|
planner.NewNodePredicate("", planner.WithNodeBinaryComparisonPredicate(
|
|
planner.NewNodeBinaryComparisonPredicate("",
|
|
planner.NewNodePredicate("", planner.WithNodeExpressionAtomPredicate(
|
|
planner.NewNodeExpressionAtomPredicate("",
|
|
planner.NewNodeExpressionAtom("", planner.ExpressionAtomWithFullColumnName(
|
|
planner.NewNodeFullColumnName("", field)))))),
|
|
planner.NewNodePredicate("", planner.WithNodeExpressionAtomPredicate(
|
|
planner.NewNodeExpressionAtomPredicate("",
|
|
planner.NewNodeExpressionAtom("", planner.ExpressionAtomWithConstant(
|
|
planner.NewNodeConstant("", planner.WithDecimalLiteral(c))))))),
|
|
op),
|
|
),
|
|
)),
|
|
)
|
|
return n
|
|
}
|
|
|
|
func Test_getOutputFieldsOrMatchCountRule(t *testing.T) {
|
|
t.Run("match count rule", func(t *testing.T) {
|
|
fl := []*planner.NodeSelectElement{
|
|
planner.NewNodeSelectElement("", planner.WithFunctionCall(
|
|
planner.NewNodeFunctionCall("", planner.WithAgg(
|
|
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
|
|
planner.NewNodeCount(""))))))),
|
|
}
|
|
_, match, err := getOutputFieldsOrMatchCountRule(fl)
|
|
assert.NoError(t, err)
|
|
assert.True(t, match)
|
|
})
|
|
|
|
t.Run("star *, not supported", func(t *testing.T) {
|
|
fl := []*planner.NodeSelectElement{
|
|
planner.NewNodeSelectElement("", planner.WithStar()),
|
|
}
|
|
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("combined", func(t *testing.T) {
|
|
fl := []*planner.NodeSelectElement{
|
|
planner.NewNodeSelectElement("", planner.WithFunctionCall(
|
|
planner.NewNodeFunctionCall("", planner.WithAgg(
|
|
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
|
|
planner.NewNodeCount(""))))))),
|
|
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
|
planner.NewNodeFullColumnName("", "field"))),
|
|
}
|
|
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("alias, not supported", func(t *testing.T) {
|
|
fl := []*planner.NodeSelectElement{
|
|
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
|
planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))),
|
|
}
|
|
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("normal case", func(t *testing.T) {
|
|
fl := []*planner.NodeSelectElement{
|
|
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
|
planner.NewNodeFullColumnName("", "field1"))),
|
|
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
|
planner.NewNodeFullColumnName("", "field2"))),
|
|
}
|
|
outputFields, match, err := getOutputFieldsOrMatchCountRule(fl)
|
|
assert.NoError(t, err)
|
|
assert.False(t, match)
|
|
assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields)
|
|
})
|
|
}
|
|
|
|
func Test_wrapCountResult(t *testing.T) {
|
|
sqlRes := wrapCountResult(100, "count(*)")
|
|
assert.Equal(t, 1, len(sqlRes.Fields))
|
|
assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
|
|
assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type)
|
|
assert.Equal(t, 1, len(sqlRes.Rows))
|
|
assert.Equal(t, 1, len(sqlRes.Rows[0]))
|
|
assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type())
|
|
}
|
|
|
|
func Test_wrapQueryResults(t *testing.T) {
|
|
res := &milvuspb.QueryResults{
|
|
Status: &commonpb.Status{},
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldName: "field",
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3, 4},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
CollectionName: "test",
|
|
}
|
|
sqlRes := wrapQueryResults(res)
|
|
assert.Equal(t, 1, len(sqlRes.Fields))
|
|
assert.Equal(t, 4, len(sqlRes.Rows))
|
|
assert.Equal(t, "field", sqlRes.Fields[0].Name)
|
|
assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type)
|
|
assert.Equal(t, 1, len(sqlRes.Rows[0]))
|
|
assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type())
|
|
}
|
|
|
|
func Test_getSQLField(t *testing.T) {
|
|
f := &schemapb.FieldData{
|
|
FieldName: "a",
|
|
Type: schemapb.DataType_Int64,
|
|
}
|
|
sf := getSQLField("t", f)
|
|
assert.Equal(t, "a", sf.Name)
|
|
assert.Equal(t, query.Type_INT64, sf.Type)
|
|
assert.Equal(t, "t", sf.Table)
|
|
}
|
|
|
|
func Test_toSQLType(t *testing.T) {
|
|
type args struct {
|
|
t schemapb.DataType
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want query.Type
|
|
}{
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Bool,
|
|
},
|
|
want: query.Type_UINT8,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Int8,
|
|
},
|
|
want: query.Type_INT8,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Int16,
|
|
},
|
|
want: query.Type_INT16,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Int32,
|
|
},
|
|
want: query.Type_INT32,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Int64,
|
|
},
|
|
want: query.Type_INT64,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Float,
|
|
},
|
|
want: query.Type_FLOAT32,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_Double,
|
|
},
|
|
want: query.Type_FLOAT64,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_VarChar,
|
|
},
|
|
want: query.Type_VARCHAR,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_FloatVector,
|
|
},
|
|
want: query.Type_NULL_TYPE,
|
|
},
|
|
{
|
|
args: args{
|
|
t: schemapb.DataType_BinaryVector,
|
|
},
|
|
want: query.Type_NULL_TYPE,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t)
|
|
})
|
|
}
|
|
}
|