milvus/internal/mysqld/executor/utils_test.go

216 lines
5.8 KiB
Go

package executor
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/stretchr/testify/assert"
"github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
)
// GenNodeExpression builds a planner expression node holding a single binary
// comparison predicate. The left operand is the column named field, the right
// operand is the decimal literal c, and op is the comparison operator.
func GenNodeExpression(field string, c int64, op planner.ComparisonOperator) *planner.NodeExpression {
	// Left-hand operand: predicate wrapping an atom that references the column.
	column := planner.NewNodeFullColumnName("", field)
	columnAtom := planner.NewNodeExpressionAtom("", planner.ExpressionAtomWithFullColumnName(column))
	lhs := planner.NewNodePredicate("", planner.WithNodeExpressionAtomPredicate(
		planner.NewNodeExpressionAtomPredicate("", columnAtom)))

	// Right-hand operand: predicate wrapping an atom that holds the constant.
	literal := planner.NewNodeConstant("", planner.WithDecimalLiteral(c))
	literalAtom := planner.NewNodeExpressionAtom("", planner.ExpressionAtomWithConstant(literal))
	rhs := planner.NewNodePredicate("", planner.WithNodeExpressionAtomPredicate(
		planner.NewNodeExpressionAtomPredicate("", literalAtom)))

	// Join both operands with the operator, then lift into an expression.
	comparison := planner.NewNodeBinaryComparisonPredicate("", lhs, rhs, op)
	predicate := planner.NewNodePredicate("", planner.WithNodeBinaryComparisonPredicate(comparison))
	return planner.NewNodeExpression("", planner.WithPredicate(predicate))
}
// Test_getOutputFieldsOrMatchCountRule covers the select-list analysis.
// A lone count(*) matches the count rule. Plain column lists become output
// fields. Star, aliases, and count(*) mixed with columns are rejected.
func Test_getOutputFieldsOrMatchCountRule(t *testing.T) {
	cases := []struct {
		name       string
		elements   []*planner.NodeSelectElement
		wantErr    bool
		wantMatch  bool
		wantFields []string // nil means the output fields are not inspected
	}{
		{
			name: "match count rule",
			elements: []*planner.NodeSelectElement{
				planner.NewNodeSelectElement("", planner.WithFunctionCall(
					planner.NewNodeFunctionCall("", planner.WithAgg(
						planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
							planner.NewNodeCount(""))))))),
			},
			wantMatch: true,
		},
		{
			name: "star *, not supported",
			elements: []*planner.NodeSelectElement{
				planner.NewNodeSelectElement("", planner.WithStar()),
			},
			wantErr: true,
		},
		{
			name: "combined",
			elements: []*planner.NodeSelectElement{
				planner.NewNodeSelectElement("", planner.WithFunctionCall(
					planner.NewNodeFunctionCall("", planner.WithAgg(
						planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
							planner.NewNodeCount(""))))))),
				planner.NewNodeSelectElement("", planner.WithFullColumnName(
					planner.NewNodeFullColumnName("", "field"))),
			},
			wantErr: true,
		},
		{
			name: "alias, not supported",
			elements: []*planner.NodeSelectElement{
				planner.NewNodeSelectElement("", planner.WithFullColumnName(
					planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))),
			},
			wantErr: true,
		},
		{
			name: "normal case",
			elements: []*planner.NodeSelectElement{
				planner.NewNodeSelectElement("", planner.WithFullColumnName(
					planner.NewNodeFullColumnName("", "field1"))),
				planner.NewNodeSelectElement("", planner.WithFullColumnName(
					planner.NewNodeFullColumnName("", "field2"))),
			},
			wantFields: []string{"field1", "field2"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			fields, match, err := getOutputFieldsOrMatchCountRule(tc.elements)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.wantMatch, match)
			if tc.wantFields != nil {
				assert.ElementsMatch(t, tc.wantFields, fields)
			}
		})
	}
}
// Test_wrapCountResult checks that a count result is wrapped as a single
// INT64 column named after the given label, with exactly one INT64 cell.
func Test_wrapCountResult(t *testing.T) {
	sqlRes := wrapCountResult(100, "count(*)")

	// assert.Len does not stop the test on failure, so its return value gates
	// each index below. Without the guard, a wrong length would panic with
	// index-out-of-range instead of reporting a readable assertion failure.
	if assert.Len(t, sqlRes.Fields, 1) {
		assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
		assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type)
	}
	if assert.Len(t, sqlRes.Rows, 1) && assert.Len(t, sqlRes.Rows[0], 1) {
		assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type())
	}
}
// Test_wrapQueryResults checks that a single Int64 column with four values is
// converted into one INT64 SQL field and four single-cell INT64 rows.
func Test_wrapQueryResults(t *testing.T) {
	res := &milvuspb.QueryResults{
		Status: &commonpb.Status{},
		FieldsData: []*schemapb.FieldData{
			{
				Type:      schemapb.DataType_Int64,
				FieldName: "field",
				Field: &schemapb.FieldData_Scalars{
					Scalars: &schemapb.ScalarField{
						Data: &schemapb.ScalarField_LongData{
							LongData: &schemapb.LongArray{
								Data: []int64{1, 2, 3, 4},
							},
						},
					},
				},
			},
		},
		CollectionName: "test",
	}

	sqlRes := wrapQueryResults(res)

	// Gate every index on the (non-fatal) length assertion so a shape mismatch
	// is reported as a failure rather than an index-out-of-range panic.
	if assert.Len(t, sqlRes.Fields, 1) {
		assert.Equal(t, "field", sqlRes.Fields[0].Name)
		assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type)
	}
	if assert.Len(t, sqlRes.Rows, 4) {
		// Check every row, not only the first, since each input value should
		// have produced its own single-column row.
		for i, row := range sqlRes.Rows {
			if assert.Len(t, row, 1, "row %d", i) {
				assert.Equal(t, query.Type_INT64, row[0].Type(), "row %d", i)
			}
		}
	}
}
// Test_getSQLField checks that an Int64 Milvus field maps to an INT64 SQL
// field that keeps the original name and carries the supplied table name.
func Test_getSQLField(t *testing.T) {
	const tableName, fieldName = "t", "a"

	input := &schemapb.FieldData{Type: schemapb.DataType_Int64, FieldName: fieldName}
	got := getSQLField(tableName, input)

	assert.Equal(t, tableName, got.Table)
	assert.Equal(t, fieldName, got.Name)
	assert.Equal(t, query.Type_INT64, got.Type)
}
// Test_toSQLType pins the mapping from Milvus schema data types to MySQL wire
// types. Vector types have no SQL counterpart and map to NULL_TYPE.
func Test_toSQLType(t *testing.T) {
	tests := []struct {
		in   schemapb.DataType
		want query.Type
	}{
		{in: schemapb.DataType_Bool, want: query.Type_UINT8},
		{in: schemapb.DataType_Int8, want: query.Type_INT8},
		{in: schemapb.DataType_Int16, want: query.Type_INT16},
		{in: schemapb.DataType_Int32, want: query.Type_INT32},
		{in: schemapb.DataType_Int64, want: query.Type_INT64},
		{in: schemapb.DataType_Float, want: query.Type_FLOAT32},
		{in: schemapb.DataType_Double, want: query.Type_FLOAT64},
		{in: schemapb.DataType_VarChar, want: query.Type_VARCHAR},
		{in: schemapb.DataType_FloatVector, want: query.Type_NULL_TYPE},
		{in: schemapb.DataType_BinaryVector, want: query.Type_NULL_TYPE},
	}
	for _, tt := range tests {
		// Name each subtest after its input type. Previously every case had an
		// empty name, so failures surfaced only as "#00", "#01", ...
		t.Run(tt.in.String(), func(t *testing.T) {
			assert.Equalf(t, tt.want, toSQLType(tt.in), "toSQLType(%v)", tt.in)
		})
	}
}