<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.10" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.68" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.8.0-SNAPSHOT" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.27.2" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.1-android" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.27.2" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.11" level="project" />
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
<orderEntry type="library" name="Maven: org.testng:testng:6.14.3" level="project" />
<orderEntry type="library" name="Maven: com.beust:jcommander:1.72" level="project" />
<orderEntry type="library" name="Maven: org.apache-extras.beanshell:bsh:2.0b6" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.68" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.73" level="project" />
<orderEntry type="library" name="Maven: junit:junit:4.13" level="project" />
<orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.8.0-SNAPSHOT" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven.plugins:maven-gpg-plugin:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-api:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-project:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-settings:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-profile:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-artifact-manager:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven.wagon:wagon-provider-api:1.0-beta-6" level="project" />
<orderEntry type="library" name="Maven: backport-util-concurrent:backport-util-concurrent:3.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-registry:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-interpolation:1.11" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-container-default:1.0-alpha-9-stable-1" level="project" />
<orderEntry type="library" name="Maven: classworlds:classworlds:1.1-alpha-2" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-artifact:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-repository-metadata:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-model:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-utils:3.0.20" level="project" />
<orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-sec-dispatcher:1.4" level="project" />
<orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-cipher:1.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.android:annotations:" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.perfmark:perfmark-api:0.19.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.9.0-SNAPSHOT" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.30.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.30.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.30.2" level="project" />
<orderEntry type="library" name="Maven: com.google.code.findbugs:jsr305:3.0.2" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.mojo:animal-sniffer-annotations:1.18" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.1-android" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.12.0" level="project" />
<orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.17.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.30.2" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.2-android" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:failureaccess:1.0.1" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:listenablefuture:9999.0-empty-to-avoid-conflict-with-guava" level="project" />
<orderEntry type="library" name="Maven: org.checkerframework:checker-compat-qual:2.5.5" level="project" />
<orderEntry type="library" name="Maven: com.google.j2objc:j2objc-annotations:1.3" level="project" />
<orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.17.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.27.2" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java-util:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" name="Maven: com.google.errorprone:error_prone_annotations:2.3.4" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: org.codehaus.mojo:animal-sniffer-annotations:1.18" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.30.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.30.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.30.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.android:annotations:" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.perfmark:perfmark-api:0.19.0" level="project" />
<orderEntry type="library" name="Maven: org.json:json:20190722" level="project" />
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.30" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.12.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.12.1" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.12.1" level="project" />
<orderEntry type="library" name="Maven: org.testcontainers:testcontainers:1.14.3" level="project" />
<orderEntry type="library" name="Maven: org.jetbrains:annotations:19.0.0" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.20" level="project" />
<orderEntry type="library" name="Maven: org.rnorth.duct-tape:duct-tape:1.0.8" level="project" />
<orderEntry type="library" name="Maven: org.rnorth.visible-assertions:visible-assertions:2.1.2" level="project" />
<orderEntry type="library" name="Maven: org.rnorth:tcp-unix-socket-proxy:1.0.2" level="project" />
<orderEntry type="library" name="Maven: com.kohlschutter.junixsocket:junixsocket-native-common:2.0.4" level="project" />
<orderEntry type="library" name="Maven: org.scijava:native-lib-loader:2.0.2" level="project" />
<orderEntry type="library" name="Maven: com.kohlschutter.junixsocket:junixsocket-common:2.0.4" level="project" />
<orderEntry type="library" name="Maven: net.java.dev.jna:jna-platform:5.5.0" level="project" />
<orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.5.0" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.slf4j:slf4j-simple:1.7.30" level="project" />
@ -118,6 +118,17 @@
<!-- <dependency>-->
<!-- <groupId>io.grpc</groupId>-->
@ -3,6 +3,7 @@ package com;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import io.milvus.client.*;
public final class Constants {
@ -20,37 +21,39 @@ public final class Constants {
public static final double epsilon = 0.001;
public static final int segmentRowLimit = 5000;
public static final String fieldNameKey = "name";
public static final String vectorType = "float";
public static final String defaultMetricType = "L2";
public static final String binaryVectorType = "binary";
public static final String indexType = "IVF_SQ8";
public static final MetricType defaultMetricType = MetricType.L2;
public static final String defaultIndexType = "FLAT";
public static final IndexType indexType = IndexType.IVF_SQ8;
public static final String defaultBinaryIndexType = "BIN_FLAT";
public static final IndexType defaultIndexType = IndexType.FLAT;
public static final String defaultBinaryMetricType = "JACCARD";
public static final IndexType defaultBinaryIndexType = IndexType.BIN_IVF_FLAT;
public static final String floatFieldName = "float_vector";
public static final MetricType defaultBinaryMetricType = MetricType.JACCARD;
public static final String binaryFieldName = "binary_vector";
public static final String intFieldName = "int64";
public static final String indexParam = Utils.setIndexParam(indexType, "L2", n_list);
public static final String floatFieldName = "float";
public static final String binaryIndexParam = Utils.setIndexParam(defaultBinaryIndexType, defaultBinaryMetricType, n_list);
public static final String floatVectorFieldName = "float_vector";
public static final String binaryVectorFieldName = "binary_vector";
public static final List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
public static final List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
public static final List<Map<String,Object>> defaultFields = Utils.genDefaultFields(dimension,false);
public static final Map<String, List> defaultEntities = Utils.genDefaultEntities(nb, vectors);
public static final List<Map<String,Object>> defaultBinaryFields = Utils.genDefaultFields(dimension,true);
public static final List<Map<String,Object>> defaultEntities = Utils.genDefaultEntities(dimension, nb, vectors);
public static final List<Map<String,Object>> defaultBinaryEntities = Utils.genDefaultBinaryEntities(dimension, nb, vectorsBinary);
public static final Map<String, List> defaultBinaryEntities = Utils.genDefaultBinaryEntities(nb, vectorsBinary);
public static final String searchParam = Utils.setSearchParam(defaultMetricType, vectors.subList(0, nq), topk, n_probe);
@ -3,11 +3,7 @@ package com;
import io.milvus.client.*;
import org.apache.commons.cli.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.SkipException;
import org.testng.TestNG;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.AfterSuite;
import org.testng.annotations.AfterTest;
import org.testng.annotations.DataProvider;
import org.testng.xml.XmlClass;
import org.testng.xml.XmlSuite;
@ -15,16 +11,17 @@ import org.testng.xml.XmlTest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class MainClass {
private static String HOST = "";
// private static String HOST = "";
private static int PORT = 19530;
private int segmentRowCount = 5000;
private static ConnectParam CONNECT_PARAM = new ConnectParam.Builder()
private static MilvusClient client;
public static void setHost(String host) {
MainClass.HOST = host;
@ -34,98 +31,85 @@ public class MainClass {
MainClass.PORT = port;
public static String getHost() {
return MainClass.HOST;
public static int getPort() {
return MainClass.PORT;
public static Object[][] defaultConnectArgs(){
return new Object[][]{{HOST, PORT}};
public Object[][] connectInstance() throws ConnectFailedException {
MilvusClient client = new MilvusGrpcClient();
public Object[][] connectInstance() throws Exception {
ConnectParam connectParam = new ConnectParam.Builder()
client = new MilvusGrpcClient(connectParam).withLogging();
String collectionName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, collectionName}};
public Object[][] disConnectInstance() throws ConnectFailedException {
public Object[][] disConnectInstance(){
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
try {
} catch (InterruptedException e) {
client = new MilvusGrpcClient(CONNECT_PARAM).withLogging();
String collectionName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, collectionName}};
private Object[][] genCollection(boolean isBinary, boolean autoId) throws ConnectFailedException {
private Object[][] genCollection(boolean isBinary, boolean autoId) throws Exception {
Object[][] collection;
String collectionName = Utils.genUniqueStr("collection");
List<Map<String, Object>> defaultFields = Utils.genDefaultFields(Constants.dimension,isBinary);
String jsonParams = String.format("{\"segment_row_count\": %s, \"auto_id\": %s}",segmentRowCount, autoId);
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
CollectionMapping cm = new CollectionMapping.Builder(collectionName)
Response res = client.createCollection(cm);
if (!res.ok()) {
throw new SkipException("Collection created failed");
client = new MilvusGrpcClient(CONNECT_PARAM).withLogging();
CollectionMapping cm = CollectionMapping
.addField(Constants.intFieldName, DataType.INT64)
.addField(Constants.floatFieldName, DataType.FLOAT)
.setParamsInJson(new JsonBuilder()
.param("segment_row_limit", segmentRowCount)
.param("auto_id", autoId)
if (isBinary) {
cm.addVectorField("binary_vector", DataType.VECTOR_BINARY, Constants.dimension);
} else {
cm.addVectorField("float_vector", DataType.VECTOR_FLOAT, Constants.dimension);
collection = new Object[][]{{client, collectionName}};
return collection;
public Object[][] provideCollection() throws ConnectFailedException, InterruptedException {
public Object[][] provideCollection() throws Exception, InterruptedException {
Object[][] collection = genCollection(false,true);
return collection;
// List<String> tableNames = client.listCollections().getCollectionNames();
// for (int j = 0; j < tableNames.size(); ++j
// ) {
// client.dropCollection(tableNames.get(j));
// }
// Thread.currentThread().sleep(2000);
public Object[][] provideIdCollection() throws ConnectFailedException, InterruptedException {
public Object[][] provideIdCollection() throws Exception, InterruptedException {
Object[][] idCollection = genCollection(false,false);
return idCollection;
public Object[][] provideBinaryCollection() throws ConnectFailedException, InterruptedException {
public Object[][] provideBinaryCollection() throws Exception, InterruptedException {
Object[][] binaryCollection = genCollection(true,true);
return binaryCollection;
public Object[][] provideBinaryIdCollection() throws ConnectFailedException, InterruptedException {
public Object[][] provideBinaryIdCollection() throws Exception, InterruptedException {
Object[][] binaryIdCollection = genCollection(true,false);
return binaryIdCollection;
public void dropCollection(){
// MilvusClient client = new MilvusGrpcClient();
// List<String> collectionNames = client.listCollections().getCollectionNames();
// collectionNames.forEach(client::dropCollection);
System.out.println("after suite");
public void after(){
System.out.println("after method");
public static void main(String[] args) {
CommandLineParser parser = new DefaultParser();
@ -156,19 +140,20 @@ public class MainClass {
List<XmlClass> classes = new ArrayList<XmlClass>();
// classes.add(new XmlClass("com.TestPing"));
// classes.add(new XmlClass("com.TestAddVectors"));
classes.add(new XmlClass("com.TestConnect"));
// classes.add(new XmlClass("com.TestDeleteVectors"));
// classes.add(new XmlClass("com.TestInsertEntities"));
// classes.add(new XmlClass("com.TestConnect"));
// classes.add(new XmlClass("com.TestDeleteEntities"));
// classes.add(new XmlClass("com.TestIndex"));
// classes.add(new XmlClass("com.TestCompact"));
// classes.add(new XmlClass("com.TestSearchVectors"));
classes.add(new XmlClass("com.TestCollection"));
// classes.add(new XmlClass("com.TestSearchEntities"));
// classes.add(new XmlClass("com.TestCollection"));
// classes.add(new XmlClass("com.TestCollectionCount"));
// classes.add(new XmlClass("com.TestFlush"));
// classes.add(new XmlClass("com.TestPartition"));
// classes.add(new XmlClass("com.TestGetVectorByID"));
// classes.add(new XmlClass("com.TestGetEntityByID"));
// classes.add(new XmlClass("com.TestCollectionInfo"));
// classes.add(new XmlClass("com.TestSearchByIds"));
classes.add(new XmlClass("com.TestBeforeAndAfter"));
test.setXmlClasses(classes) ;
package com;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class TestCollection {
int segmentRowCount = 5000;
int dimension = 128;
public MilvusClient setUp() throws ConnectFailedException {
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
return client;
public void tearDown() throws ConnectFailedException {
MilvusClient client = setUp();
List<String> collectionNames = client.listCollections().getCollectionNames();
// collectionNames.forEach(collection -> {client.dropCollection(collection);});
for(String collection: collectionNames){
System.out.print(collection+" ");
System.out.println("After Test");
String intFieldName = Constants.intFieldName;
String floatFieldName = Constants.floatFieldName;
String floatVectorFieldName = Constants.floatVectorFieldName;
Boolean autoId = true;
// case-01
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testCreateCollection(MilvusClient client, String collectionName){
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
Response res = client.createCollection(collectionSchema);
Assert.assertEquals(res.ok(), true);
public void testCreateCollection(MilvusClient client, String collectionName) {
// Generate connection instance
CollectionMapping cm = Utils.genCreateCollectionMapping(collectionName, true,false);
Assert.assertEquals(client.hasCollection(collectionName), true);
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void testCreateCollectionDisconnect(MilvusClient client, String collectionName){
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
Response res = client.createCollection(collectionSchema);
// case-02
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testCreateCollectionDisconnect(MilvusClient client, String collectionName) {
CollectionMapping cm = Utils.genCreateCollectionMapping(collectionName, true, false);
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testCreateCollectionRepeatably(MilvusClient client, String collectionName){
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
Response res = client.createCollection(collectionSchema);
Assert.assertEquals(res.ok(), true);
Response resNew = client.createCollection(collectionSchema);
Assert.assertEquals(resNew.ok(), false);
// case-03
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreateCollectionRepeatably(MilvusClient client, String collectionName) {
CollectionMapping cm = Utils.genCreateCollectionMapping(collectionName, true, false);
Assert.assertEquals(client.hasCollection(collectionName), true);
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testCreateCollectionWrongParams(MilvusClient client, String collectionName){
// case-04
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreateCollectionWrongParams(MilvusClient client, String collectionName) {
Integer dim = 0;
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
Response res = client.createCollection(collectionSchema);
Assert.assertEquals(res.ok(), false);
CollectionMapping cm = CollectionMapping.create(collectionName)
.addField(intFieldName, DataType.INT64)
.addField(floatFieldName, DataType.FLOAT)
.addVectorField(floatVectorFieldName, DataType.VECTOR_FLOAT, dim)
.setParamsInJson(new JsonBuilder()
.param("segment_row_limit", Constants.segmentRowLimit)
.param("auto_id", autoId)
// case-05
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testShowCollections(MilvusClient client, String collectionName){
public void testShowCollections(MilvusClient client, String collectionName) {
Integer collectionNum = 10;
ListCollectionsResponse res = null;
List<String> originCollections = new ArrayList<>();
for (int i = 0; i < collectionNum; ++i) {
String collectionNameNew = collectionName+"_"+Integer.toString(i);
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionNameNew)
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
List<String> collectionNames = client.listCollections().getCollectionNames();
CollectionMapping cm = Utils.genCreateCollectionMapping(collectionNameNew, true, false);
List<String> listCollections = client.listCollections();
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void testShowCollectionsWithoutConnect(MilvusClient client, String collectionName){
ListCollectionsResponse res = client.listCollections();
// case-06
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testShowCollectionsWithoutConnect(MilvusClient client, String collectionName) {
// case-07
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDropCollection(MilvusClient client, String collectionName) throws InterruptedException {
Response res = client.dropCollection(collectionName);
List<String> collectionNames = client.listCollections().getCollectionNames();
public void testDropCollection(MilvusClient client, String collectionName) {
Assert.assertEquals(client.hasCollection(collectionName), false);
// Thread.currentThread().sleep(1000);
List<String> collectionNames = client.listCollections();
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-08
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropCollectionNotExisted(MilvusClient client, String collectionName) {
Response res = client.dropCollection(collectionName+"_");
List<String> collectionNames = client.listCollections().getCollectionNames();
List<String> collectionNames = client.listCollections();
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-09
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testDropCollectionWithoutConnect(MilvusClient client, String collectionName) {
Response res = client.dropCollection(collectionName);
// case-10
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDescribeCollection(MilvusClient client, String collectionName) {
GetCollectionInfoResponse res = client.getCollectionInfo(collectionName);
CollectionMapping collectionSchema = res.getCollectionMapping().get();
List<Map<String,Object>> fields = (List<Map<String, Object>>) collectionSchema.getFields();
CollectionMapping info = client.getCollectionInfo(collectionName);
List<Map<String,Object>> fields = info.getFields();
int dim = 0;
for(Map<String,Object> field: fields){
if ("float_vector".equals(field.get("field"))) {
if (floatVectorFieldName.equals(field.get(Constants.fieldNameKey))) {
JSONObject jsonObject = JSONObject.parseObject(field.get("params").toString());
String dimParams = jsonObject.getString("params");
dim = Utils.getParam(dimParams,"dim");
dim = jsonObject.getIntValue("dim");
String segmentParams = collectionSchema.getParamsInJson();
Assert.assertEquals(dim, dimension);
Assert.assertEquals(collectionSchema.getCollectionName(), collectionName);
Assert.assertEquals(Utils.getParam(segmentParams,"segment_row_count"), segmentRowCount);
JSONObject params = JSONObject.parseObject(info.getParamsInJson().toString());
Assert.assertEquals(dim, Constants.dimension);
Assert.assertEquals(info.getCollectionName(), collectionName);
Assert.assertEquals(params.getIntValue("segment_row_limit"), Constants.segmentRowLimit);
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-11
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testDescribeCollectionWithoutConnect(MilvusClient client, String collectionName) {
GetCollectionInfoResponse res = client.getCollectionInfo(collectionName);
// case-12
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testHasCollectionNotExisted(MilvusClient client, String collectionName) {
HasCollectionResponse res = client.hasCollection(collectionName+"_");
String collectionNameNew = collectionName+"_";
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-13
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testHasCollectionWithoutConnect(MilvusClient client, String collectionName) {
HasCollectionResponse res = client.hasCollection(collectionName);
// case-14
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testHasCollection(MilvusClient client, String collectionName) {
HasCollectionResponse res = client.hasCollection(collectionName);
package com;
import io.milvus.client.*;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
public class TestCollectionCount {
int segmentRowCount = 5000;
int dimension = Constants.dimension;
int nb = Constants.nb;
// case-01
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCollectionCountNoVectors(MilvusClient client, String collectionName) {
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
Assert.assertEquals(client.countEntities(collectionName), 0);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-02
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCollectionCountCollectionNotExisted(MilvusClient client, String collectionName) {
CountEntitiesResponse res = client.countEntities(collectionName+"_");
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-03
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testCollectionCountWithoutConnect(MilvusClient client, String collectionName) {
CountEntitiesResponse res = client.countEntities(collectionName+"_");
// case-04
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCollectionCount(MilvusClient client, String collectionName) throws InterruptedException {
InsertParam insertParam =
new InsertParam.Builder(collectionName)
InsertResponse insertResponse = client.insert(insertParam);
// Insert returns a list of entity ids (generated for you if you did not supply them yourself) that you will use to reference the entities you just inserted
List<Long> vectorIds = insertResponse.getEntityIds();
// Add vectors
Response flushResponse = client.flush(collectionName);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testCollectionCountBinary(MilvusClient client, String collectionName) throws InterruptedException {
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionName)
public void testCollectionCount(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
// Insert returns a list of entity ids (generated for you if you did not supply them yourself) that you will use to reference the entities you just inserted
Assert.assertEquals(client.countEntities(collectionName), nb);
// case-05
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testCollectionCountBinary(MilvusClient client, String collectionName) {
// Add vectors
InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Assert.assertEquals(client.countEntities(collectionName), nb);
// case-06
// Works on collectionNum collections named collectionName + "_" + i: the first
// loop sets each one up and inserts into it, the second loop asserts that every
// one of them reports nb entities.
// NOTE(review): two signatures and two insert styles (InsertParam.Builder and
// Utils.genInsertParam) appear side by side; this looks like two revisions of
// the same method interleaved — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCollectionCountMultiCollections(MilvusClient client, String collectionName) throws InterruptedException {
public void testCollectionCountMultiCollections(MilvusClient client, String collectionName) {
Integer collectionNum = 10;
CountEntitiesResponse res;
for (int i = 0; i < collectionNum; ++i) {
// Each iteration targets its own collection, derived from the provided name.
String collectionNameNew = collectionName + "_" + i;
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionNameNew)
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
Response cteateRes = client.createCollection(collectionSchema);
Assert.assertEquals(cteateRes.ok(), true);
CollectionMapping cm = Utils.genCreateCollectionMapping(collectionNameNew, true, false);
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionNameNew)
InsertResponse insertRes = client.insert(insertParam);
Assert.assertEquals(insertRes.ok(), true);
Response flushRes = client.flush(collectionNameNew);
Assert.assertEquals(flushRes.ok(), true);
InsertParam insertParam = Utils.genInsertParam(collectionNameNew);
List<Long> ids = client.insert(insertParam);
// Second pass: verify the entity count of every collection touched above.
for (int i = 0; i < collectionNum; ++i) {
String collectionNameNew = collectionName + "_" + i;
res = client.countEntities(collectionNameNew);
Assert.assertEquals(res.getCollectionEntityCount(), nb);
Assert.assertEquals(client.countEntities(collectionNameNew), nb);
package com;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.annotations.Test;
import java.util.Collections;
import java.util.List;
public class TestCollectionInfo {
int nb = Constants.nb;
// Inserts entities, deletes the first returned id, then checks that the
// "row_count" value read from the collection stats equals nb - 1.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetEntityIdsAfterDeleteEntities(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName)
InsertResponse resInsert = client.insert(insertParam);
List<Long> idsBefore = resInsert.getEntityIds();
// Delete only the first inserted entity.
client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
Response res = client.getCollectionStats(collectionName);
JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
int rowCount = collectionInfo.getIntValue("row_count");
// Java assert statement: only evaluated when JVM assertions are enabled.
assert(rowCount == nb-1);
// Inserts entities, creates an index on the "float_vector" field, deletes the
// first returned id, then checks that the stats "row_count" equals nb - 1.
// NOTE(review): "Ater" in the method name looks like a typo for "After";
// renaming would change the test's public name, so it is left as is.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetEntityIdsAterDeleteEntitiesIndexed(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName)
InsertResponse resInsert = client.insert(insertParam);
Index index = new Index.Builder(collectionName, "float_vector")
Response createIndexResponse = client.createIndex(index);
List<Long> idsBefore = resInsert.getEntityIds();
// Delete only the first inserted entity.
client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
Response res = client.getCollectionStats(collectionName);
JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
int rowCount = collectionInfo.getIntValue("row_count");
// Java assert statement: only evaluated when JVM assertions are enabled.
assert(rowCount == nb-1);
// Same flow as the float case, run against the "BinaryCollection" provider:
// insert, delete the first returned id, expect stats "row_count" == nb - 1.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testGetEntityIdsAfterDeleteEntitiesBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName)
InsertResponse resInsert = client.insert(insertParam);
List<Long> idsBefore = resInsert.getEntityIds();
// Delete only the first inserted entity.
client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
Response res = client.getCollectionStats(collectionName);
JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
int rowCount = collectionInfo.getIntValue("row_count");
// Java assert statement: only evaluated when JVM assertions are enabled.
assert(rowCount == nb-1);
@ -0,0 +1,62 @@
package com;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.annotations.Test;
import java.util.Collections;
import java.util.List;
public class TestCollectionStats {
int nb = Constants.nb;
// case-01
// Inserts entities, deletes the first returned id, parses the stats string as
// JSON and checks that its "row_count" equals nb - 1.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetEntityIdsAfterDeleteEntities(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> idsBefore = client.insert(insertParam);
// Delete only the first inserted entity.
client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
String stats = client.getCollectionStats(collectionName);
JSONObject collectionStats = JSONObject.parseObject(stats);
int rowCount = collectionStats.getIntValue("row_count");
// Java assert statement: only evaluated when JVM assertions are enabled.
assert(rowCount == nb - 1);
// case-02
// Inserts entities, builds an Index description for the float vector field
// with an "nlist" parameter, deletes the first returned id, and checks that
// the stats "row_count" equals nb - 1.
// NOTE(review): no createIndex call is visible between building the Index and
// the delete — confirm the index is actually created on the server.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetEntityIdsAfterDeleteEntitiesIndexed(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> idsBefore = client.insert(insertParam);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
// Delete only the first inserted entity.
client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
String stats = client.getCollectionStats(collectionName);
JSONObject collectionStats = JSONObject.parseObject(stats);
int rowCount = collectionStats.getIntValue("row_count");
// Java assert statement: only evaluated when JVM assertions are enabled.
assert(rowCount == nb - 1);
// case-03
// Binary-collection variant: inserts the entities from
// Utils.genBinaryInsertParam, deletes the first returned id, and checks that
// the stats "row_count" equals nb - 1.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testGetEntityIdsAfterDeleteEntitiesBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
List<Long> idsBefore = client.insert(insertParam);
// Delete only the first inserted entity.
client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
String stats = client.getCollectionStats(collectionName);
JSONObject collectionStats = JSONObject.parseObject(stats);
int rowCount = collectionStats.getIntValue("row_count");
// Java assert statement: only evaluated when JVM assertions are enabled.
assert(rowCount == nb - 1);
package com;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
@ -10,78 +10,156 @@ import java.util.List;
public class TestCompact {
int nb = Constants.nb;
// case-01
// Inserts entities, deletes all of them by id, compacts the collection, and
// asserts that the stats "data_size" and the entity count are both 0.
// NOTE(review): several statements appear twice in different API styles (two
// declarations of ids, Response-based and String-based stats); this looks like
// two revisions interleaved — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCompactAfterDelete(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
List<Long> ids = Utils.initData(client, collectionName);
// Remove every inserted entity before compacting.
client.deleteEntityByID(collectionName, ids);
Response res_delete = client.deleteEntityByID(collectionName, ids);
CompactParam compactParam = new CompactParam.Builder(collectionName).build();
Response res_compact = client.compact(compactParam);
Response statsResponse = client.getCollectionStats(collectionName);
assert (statsResponse.ok());
JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
String statsResponse = client.getCollectionStats(collectionName);
JSONObject jsonObject = JSONObject.parseObject(statsResponse);
Assert.assertEquals(jsonObject.getIntValue("data_size"), 0);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
Assert.assertEquals(client.countEntities(collectionName), 0);
// case-02
// Binary-collection variant: inserts entities, deletes all of them by id,
// compacts, and asserts that the stats "data_size" and the entity count are 0.
// NOTE(review): several statements appear twice in different API styles; this
// looks like two revisions interleaved — confirm which lines are current.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testCompactAfterDeleteBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
List<Long> ids = Utils.initBinaryData(client, collectionName);
// Remove every inserted entity before compacting.
client.deleteEntityByID(collectionName, ids);
Response res_delete = client.deleteEntityByID(collectionName, ids);
CompactParam compactParam = new CompactParam.Builder(collectionName).build();
Response res_compact = client.compact(compactParam);
Response statsResponse = client.getCollectionStats(collectionName);
assert (statsResponse.ok());
JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
String stats = client.getCollectionStats(collectionName);
JSONObject jsonObject = JSONObject.parseObject(stats);
Assert.assertEquals(jsonObject.getIntValue("data_size"), 0);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
Assert.assertEquals(client.countEntities(collectionName), 0);
// Calls compact with an empty collection name; the second annotation below
// declares ServerSideMilvusException as the expected outcome.
// NOTE(review): two @Test annotations precede the signature — confirm which
// one is current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-03
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCompactNoCollection(MilvusClient client, String collectionName) {
// The provided collectionName is ignored; an empty name is used instead.
String name = "";
CompactParam compactParam = new CompactParam.Builder(name).build();
Response res_compact = client.compact(compactParam);
// case-04
// Compacts the provided collection without inserting anything in this test
// and checks that the "data_size" value in its stats is 0.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCompactEmptyCollection(MilvusClient client, String collectionName) {
CompactParam compactParam = new CompactParam.Builder(collectionName).build();
Response res_compact = client.compact(compactParam);
String stats = client.getCollectionStats(collectionName);
JSONObject jsonObject = JSONObject.parseObject(stats);
int data_size = jsonObject.getIntValue("data_size");
// NOTE(review): TestNG's assertEquals takes (actual, expected); the arguments
// here are in the opposite order from the other assertions in this class,
// which only affects the wording of a failure message.
Assert.assertEquals(0, data_size);
// case-05
// Deletes the first quarter of the inserted ids (nb / 4) and asserts that the
// first segment's "data_size" in the stats is the same before and after
// compaction. The Response-based lines compact with a threshold of 0.3.
// NOTE(review): Response-based and String-based variants of the same steps
// appear side by side (looks like two revisions interleaved), and no compact
// call is visible between the "before compact" and "after compact" stats in
// the String-based variant — confirm which lines are current.
// TODO delete not correct
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCompactThresholdLessThanDeleted(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
Response deleteRes = client.deleteEntityByID(collectionName, res.getEntityIds().subList(0, nb/4));
assert (deleteRes.ok());
Response resBefore = client.getCollectionStats(collectionName);
JSONObject segmentsBefore = (JSONObject)Utils.parseJsonArray(resBefore.getMessage(), "segments").get(0);
CompactParam compactParam = new CompactParam.Builder(collectionName).withThreshold(0.3).build();
Response resCompact = client.compact(compactParam);
assert (resCompact.ok());
Response resAfter = client.getCollectionStats(collectionName);
JSONObject segmentsAfter = (JSONObject)Utils.parseJsonArray(resAfter.getMessage(), "segments").get(0);
Assert.assertEquals(segmentsBefore.get("data_size"), segmentsAfter.get("data_size"));
// The String-based variant works on a separate collection whose
// "segment_row_limit" is set to nb + 1000.
int segmentRowLimit = nb+1000;
String collectionNameNew = collectionName+"_";
CollectionMapping cm = CollectionMapping.create(collectionNameNew)
.addField(Constants.intFieldName, DataType.INT64)
.addField(Constants.floatFieldName, DataType.FLOAT)
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, Constants.dimension)
.setParamsInJson(new JsonBuilder()
.param("segment_row_limit", segmentRowLimit)
.param("auto_id", true)
List<Long> ids = Utils.initData(client, collectionNameNew);
client.deleteEntityByID(collectionNameNew, ids.subList(0, nb / 4));
Assert.assertEquals(client.countEntities(collectionNameNew), nb - (nb / 4));
// before compact
String stats = client.getCollectionStats(collectionNameNew);
JSONObject segmentsBefore = (JSONObject)Utils.parseJsonArray(stats, "segments").get(0);
// after compact
String statsAfter = client.getCollectionStats(collectionNameNew);
JSONObject segmentsAfter = (JSONObject)Utils.parseJsonArray(statsAfter, "segments").get(0);
Assert.assertEquals(segmentsAfter.get("data_size"), segmentsBefore.get("data_size"));
// case-06
// Declared to expect ServerSideMilvusException. The visible lines only load
// data via Utils.initData and delete all returned ids.
// NOTE(review): the compact call with the invalid threshold that the name
// refers to is not visible in this view — confirm it follows the delete.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCompactInvalidThreshold(MilvusClient client, String collectionName) {
List<Long> ids = Utils.initData(client, collectionName);
client.deleteEntityByID(collectionName, ids);
// // case-07, test CompactAsync callback
// @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// public void testCompactAsyncAfterDelete(MilvusClient client, String collectionName) {
// // define callback
// FutureCallback<Response> futureCallback = new FutureCallback<Response>() {
// @Override
// public void onSuccess(Response compactResponse) {
// assert(compactResponse != null);
// assert(compactResponse.ok());
// Response statsResponse = client.getCollectionStats(collectionName);
// assert(statsResponse.ok());
// JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
// Assert.assertEquals(jsonObject.getIntValue("data_size"), 0);
// Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
// }
// @Override
// public void onFailure(Throwable t) {
// System.out.println(t.getMessage());
// Assert.assertTrue(false);
// }
// };
// InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
// InsertResponse res = client.insert(insertParam);
// assert(res.getResponse().ok());
// List<Long> ids = res.getEntityIds();
// client.flush(collectionName);
// Response res_delete = client.deleteEntityByID(collectionName, ids);
// assert(res_delete.ok());
// client.flush(collectionName);
// CompactParam compactParam = new CompactParam.Builder(collectionName).build();
// // call compactAsync
// ListenableFuture<Response> compactResponseFuture = client.compactAsync(compactParam);
// Futures.addCallback(compactResponseFuture, futureCallback, MoreExecutors.directExecutor());
// // execute before callback
// Response statsResponse = client.getCollectionStats(collectionName);
// assert(statsResponse.ok());
// JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
// Assert.assertTrue(jsonObject.getIntValue("data_size") > 0);
// Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
// }
// // case-08, test CompactAsync callback with invalid collection name
// @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// public void testCompactAsyncNoCollection(MilvusClient client, String collectionName) {
// // define callback
// FutureCallback<Response> futureCallback = new FutureCallback<Response>() {
// @Override
// public void onSuccess(Response compactResponse) {
// assert(compactResponse != null);
// assert(!compactResponse.ok());
// }
// @Override
// public void onFailure(Throwable t) {
// System.out.println(t.getMessage());
// Assert.assertTrue(false);
// }
// };
// String name = "";
// CompactParam compactParam = new CompactParam.Builder(name).build();
// // call compactAsync
// ListenableFuture<Response> compactResponseFuture = client.compactAsync(compactParam);
// Futures.addCallback(compactResponseFuture, futureCallback, MoreExecutors.directExecutor());
// }
package com;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.concurrent.TimeUnit;
public class TestConnect {
// Prints the host and port supplied by the "DefaultConnectArgs" provider and
// opens a client connection with a ConnectParam built for them.
// NOTE(review): two signatures and two connection styles (client.connect(...)
// and the MilvusGrpcClient(connectParam) constructor) appear side by side;
// this looks like two revisions interleaved — confirm which lines are current.
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void testConnect(String host, int port) throws ConnectFailedException {
public void testConnect(String host, int port) throws Exception {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
Response res = client.connect(connectParam);
// assert(client.isConnected());
MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
// Connects twice with the same ConnectParam: one variant calls
// client.connect(connectParam) twice, the other constructs two
// MilvusGrpcClient instances from it.
// NOTE(review): both variants appear side by side (looks like two revisions
// interleaved) — confirm which lines are current.
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void testConnectRepeat(String host, int port) {
MilvusGrpcClient client = new MilvusGrpcClient();
Response res = null;
try {
ConnectParam connectParam = new ConnectParam.Builder()
res = client.connect(connectParam);
res = client.connect(connectParam);
} catch (ConnectFailedException e) {
// NOTE(review): res is still null here if the first connect call threw.
assert (res.ok());
// assert(client.isConnected());
ConnectParam connectParam = new ConnectParam.Builder()
MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
MilvusClient client1 = new MilvusGrpcClient(connectParam).withLogging();
// Attempts to connect with an (ip, port) pair; the annotated variant takes its
// arguments from the "InvalidConnectArgs" provider and expects
// ClientSideMilvusException or IllegalArgumentException.
// NOTE(review): two differently named signatures appear here
// (testConnectInvalidConnect_args and testConnectInvalidConnectArgs); this
// looks like two revisions interleaved — confirm which lines are current.
public void testConnectInvalidConnect_args(String ip, int port) {
MilvusClient client = new MilvusGrpcClient();
Response res = null;
try {
ConnectParam connectParam = new ConnectParam.Builder()
res = client.connect(connectParam);
} catch (Exception e) {
// res keeps its initial null value when connect throws before assigning it.
Assert.assertEquals(res, null);
// assert(!client.isConnected());
// TODO timeout
@Test(dataProvider="InvalidConnectArgs", expectedExceptions = {ClientSideMilvusException.class, IllegalArgumentException.class})
public void testConnectInvalidConnectArgs(String ip, int port) {
ConnectParam connectParam = new ConnectParam.Builder()
.withKeepAliveTimeout(1, TimeUnit.MILLISECONDS)
MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
@ -62,27 +46,21 @@ public class TestConnect {
{"", port},
{"", port},
{"1.2.2", port},
// {"中文", port},
// {"www.baidu.com", 100000},
{"中文", port},
{"www.baidu.com", 100000},
{"", 100000},
{"www.baidu.com", 80},
// Disconnect test; the only body line visible here is a commented-out
// isConnected assertion.
// NOTE(review): two @Test annotations with different data providers
// ("DisConnectInstance" and "ConnectInstance") precede the signature —
// confirm which one is current.
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testDisconnect(MilvusClient client, String collectionName){
// assert(!client.isConnected());
// Calls client.disconnect() on a client from the "DisConnectInstance" provider
// inside a try block that catches InterruptedException.
// NOTE(review): the name says "Repeatably" but only one disconnect() call is
// visible here — confirm the intended number of calls.
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void testDisconnectRepeatably(MilvusClient client, String collectionName){
Response res = null;
try {
res = client.disconnect();
} catch (InterruptedException e) {
// assert(!client.isConnected());
// // TODO
// @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
// public void testDisconnectRepeatably(MilvusClient client, String collectionName){
// client.close();
// }
package com;
import io.milvus.client.*;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
public class TestDeleteEntities {
// case-01
// Inserts entities, deletes all returned ids, and asserts that the collection
// then reports 0 entities.
// NOTE(review): insertParam and ids are each declared twice in different API
// styles (looks like two revisions interleaved) — confirm which lines are
// current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDeleteEntities(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Response res_delete = client.deleteEntityByID(collectionName, ids);
client.deleteEntityByID(collectionName, ids);
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
Assert.assertEquals(client.countEntities(collectionName), 0);
// case-02
// Inserts entities, deletes only the first returned id, asserts the count is
// Constants.nb - 1, and asserts that looking up the deleted id returns nothing.
// NOTE(review): Response-based and Map-based variants appear side by side
// (looks like two revisions interleaved) — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDeleteSingleEntity(MilvusClient client, String collectionName) {
List<List<Float>> del_vector = new ArrayList<>();
List<Long> del_ids = new ArrayList<>();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
Response res_delete = client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
List<Long> ids = Utils.initData(client, collectionName);
client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), Constants.nb - 1);
Assert.assertEquals(client.countEntities(collectionName), Constants.nb - 1);
// Assert getEntityByID
// NOTE(review): del_ids is never filled in the visible lines, so this lookup
// is made with an empty id list rather than the deleted id.
GetEntityByIDResponse res_get = client.getEntityByID(collectionName, del_ids);
Assert.assertEquals(res_get.getValidIds().size(), 0);
Map<Long, Map<String, Object>> resEntity = client.getEntityByID(collectionName, Collections.singletonList(ids.get(0)));
Assert.assertEquals(resEntity.size(), 0);
// Calls deleteEntityByID against a name produced by
// Utils.genUniqueStr(collectionName) instead of the provided collection; the
// second annotation declares ServerSideMilvusException as the expected outcome.
// NOTE(review): two @Test annotations and two delete calls appear side by side
// (looks like two revisions interleaved) — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-03
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDeleteEntitiesCollectionNotExisted(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
String collectionNameNew = Utils.genUniqueStr(collectionName);
Response res_delete = client.deleteEntityByID(collectionNameNew, ids);
client.deleteEntityByID(collectionNameNew, new ArrayList<Long>());
// Deletes the ids 0 .. Constants.nb - 1 from a name produced by
// Utils.genUniqueStr(collectionName); the second annotation declares
// ServerSideMilvusException as the expected outcome.
// NOTE(review): two @Test annotations and two delete calls appear side by side
// (looks like two revisions interleaved) — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-04
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDeleteEntitiesEmptyCollection(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
// LongStream.range excludes the upper bound, so this yields Constants.nb ids.
List<Long> entityIds = LongStream.range(0, Constants.nb).boxed().collect(Collectors.toList());
Response res_delete = client.deleteEntityByID(collectionNameNew, entityIds);
client.deleteEntityByID(collectionNameNew, entityIds);
// case-05
// Inserts entities, issues a delete, and asserts that the collection still
// reports Constants.nb entities afterwards.
// NOTE(review): ids is declared twice (an empty list and the result of
// Utils.initData) and two delete calls appear; this looks like two revisions
// interleaved — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDeleteEntityIdNotExisted(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = new ArrayList<Long>();
List<Long> ids = Utils.initData(client, collectionName);
// delIds stays empty, so this delete call names no entity ids.
List<Long> delIds = new ArrayList<Long>();
client.deleteEntityByID(collectionName, delIds);
Response res_delete = client.deleteEntityByID(collectionName, ids);
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), Constants.nb);
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName), Constants.nb);
// case-06
// Below tests binary vectors
// Inserts binary entities, deletes all returned ids, and asserts that the
// collection then reports 0 entities.
// NOTE(review): ids is declared twice in different API styles (looks like two
// revisions interleaved) — confirm which lines are current.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testDeleteEntitiesBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
Response res_delete = client.deleteEntityByID(collectionName, ids);
List<Long> ids = Utils.initBinaryData(client, collectionName);
client.deleteEntityByID(collectionName, ids);
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
Assert.assertEquals(client.countEntities(collectionName), 0);
package com;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.concurrent.ExecutionException;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.List;
public class TestFlush {
int segmentRowCount = 50;
int nb = Constants.nb;
// Flushes the hard-coded name "not_existed", which this test never creates.
// NOTE(review): the try/catch below has no visible body in this view —
// confirm what it is meant to wrap.
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testFlushCollectionNotExisted(MilvusClient client, String collectionName) {
String newCollection = "not_existed";
Response res = client.flush(newCollection);
try {
} catch (Exception e) {
// Flushes the provided collection without inserting anything in this test.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testFlushEmptyCollection(MilvusClient client, String collectionName) {
Response res = client.flush(collectionName);
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
@ -32,64 +32,66 @@ public class TestFlush {
int collectionNum = 10;
for (int i = 0; i < collectionNum; i++) {
CollectionMapping collectionSchema = new CollectionMapping.Builder(names.get(i))
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFields(Constants.defaultEntities).build();
CollectionMapping cm = CollectionMapping
.addField("int64", DataType.INT64)
.addField("float", DataType.FLOAT)
.addVectorField("float_vector", DataType.VECTOR_FLOAT, Constants.dimension)
.setParamsInJson(new JsonBuilder()
.param("segment_row_limit", Constants.segmentRowLimit)
.param("auto_id", true)
InsertParam insertParam = Utils.genInsertParam(names.get(i));
System.out.println("Table " + names.get(i) + " created.");
Response res = client.flush(names);
for (int i = 0; i < collectionNum; i++) {
// check row count
Assert.assertEquals(client.countEntities(names.get(i)).getCollectionEntityCount(), nb);
// Sets up 100 collections taken from the names list, flushes them all with
// client.flushAsync(names), then asserts that each one reports nb entities.
// NOTE(review): names.get(i) is read in the first loop but no names.add(...)
// call is visible in this view, and the builder chains are cut short —
// confirm the missing lines.
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void testAddCollectionsFlushAsync(MilvusClient client, String collectionName) throws ExecutionException, InterruptedException {
List<String> names = new ArrayList<>();
for (int i = 0; i < 100; i++) {
CollectionMapping collectionSchema = new CollectionMapping.Builder(names.get(i))
.withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFields(Constants.defaultEntities).build();
System.out.println("Collection " + names.get(i) + " created.");
ListenableFuture<Response> flushResponseFuture = client.flushAsync(names);
for (int i = 0; i < 100; i++) {
// check row count
Assert.assertEquals(client.countEntities(names.get(i)).getCollectionEntityCount(), nb);
Assert.assertEquals(client.countEntities(names.get(i)), nb);
// @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
// public void testAddCollectionsFlushAsync(MilvusClient client, String collectionName) throws ExecutionException, InterruptedException {
// List<String> names = new ArrayList<>();
// for (int i = 0; i < 100; i++) {
// names.add(RandomStringUtils.randomAlphabetic(10));
// CollectionMapping collectionSchema = new CollectionMapping.Builder(names.get(i))
// .withFields(Constants.defaultFields)
// .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
// .build();
// client.createCollection(collectionSchema);
// InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFields(Constants.defaultEntities).build();
// client.insert(insertParam);
// System.out.println("Collection " + names.get(i) + " created.");
// }
// ListenableFuture<Response> flushResponseFuture = client.flushAsync(names);
// flushResponseFuture.get();
// for (int i = 0; i < 100; i++) {
// // check row count
// Assert.assertEquals(client.countEntities(names.get(i)).getCollectionEntityCount(), nb);
// }
// }
// Runs 10 rounds against the same collection and, after round i, asserts that
// the entity count equals nb * (i+1).
// NOTE(review): a flush-based variant and an insert-based variant of the loop
// body appear side by side (looks like two revisions interleaved) — confirm
// which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testAddFlushMultipleTimes(MilvusClient client, String collectionName) {
for (int i = 0; i < 10; i++) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
Response res = client.flush(collectionName);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb * (i+1));
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Assert.assertEquals(client.countEntities(collectionName), nb * (i+1));
// Binary-collection variant: runs 10 rounds against the same collection and,
// after round i, asserts that the entity count equals nb * (i+1).
// NOTE(review): a flush-based variant and an insert-based variant of the loop
// body appear side by side (looks like two revisions interleaved) — confirm
// which lines are current.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testAddFlushMultipleTimesBinary(MilvusClient client, String collectionName) {
for (int i = 0; i < 10; i++) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
Response res = client.flush(collectionName);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb * (i+1));
InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Assert.assertEquals(client.countEntities(collectionName), nb * (i+1));
package com;
import io.milvus.client.*;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.nio.ByteBuffer;
import java.util.*;
public class TestGetEntityByID {
public List<Long> get_ids = Utils.toListIds(1111);
@ -13,77 +14,77 @@ public class TestGetEntityByID {
// Inserts entities, fetches the first get_length ids with getEntityByID, and
// checks that each returned "float_vector" equals Constants.vectors.get(i).
// NOTE(review): a GetEntityByIDResponse-based variant and a Map-based variant
// appear side by side (looks like two revisions interleaved) — confirm which
// lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetEntitiesByIdValid(MilvusClient client, String collectionName) {
int get_length = 100;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, get_length));
assert (res.getResponse().ok());
// assert (res.getValidIds(), ids.subList(0, get_length));
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, ids.subList(0, get_length));
for (int i = 0; i < get_length; i++) {
List<Map<String,Object>> fieldsMap = res.getFieldsMap();
assert (fieldsMap.get(i).get("float_vector").equals(Constants.vectors.get(i)));
// The Map-based result is keyed by entity id, so look up by ids.get(i).
Map<String,Object> fieldsMap = resEntities.get(ids.get(i));
assert (fieldsMap.get("float_vector").equals(Constants.vectors.get(i)));
// Inserts entities, deletes one of them, and checks what getEntityByID returns
// afterwards. The Map-based lines delete ids.get(1), fetch the first two ids,
// and expect exactly one entity back whose float vector equals
// Constants.vectors.get(0).
// NOTE(review): the Response-based lines delete ids.get(0) instead and expect
// an empty result; both variants appear side by side (looks like two revisions
// interleaved) — confirm which lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetEntityByIdAfterDelete(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getEntityIds();
Response res_delete = client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
List<Long> ids = Utils.initData(client, collectionName);
client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(1)));
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
assert (res.getResponse().ok());
assert (res.getFieldsMap().size() == 0);
List<Long> getIds = ids.subList(0,2);
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, getIds);
// One of the two requested ids was deleted above, hence size() - 1.
Assert.assertEquals(resEntities.size(), getIds.size()-1);
Assert.assertEquals(resEntities.get(getIds.get(0)).get(Constants.floatVectorFieldName), Constants.vectors.get(0));
// Calls getEntityByID with the class-level get_ids against the hard-coded name
// "not_existed"; the second annotation declares ServerSideMilvusException as
// the expected outcome.
// NOTE(review): two @Test annotations and two getEntityByID calls appear side
// by side (looks like two revisions interleaved) — confirm which lines are
// current.
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testGetEntityByIdCollectionNameNotExisted(MilvusClient client, String collectionName) {
String newCollection = "not_existed";
GetEntityByIDResponse res = client.getEntityByID(newCollection, get_ids);
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(newCollection, get_ids);
// Looks up the class-level get_ids in the provided collection and asserts that
// no entities are returned.
// NOTE(review): a GetEntityByIDResponse-based variant and a Map-based variant
// appear side by side (looks like two revisions interleaved) — confirm which
// lines are current.
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testGetVectorIdNotExisted(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
GetEntityByIDResponse res = client.getEntityByID(collectionName, get_ids);
assert (res.getFieldsMap().size() == 0);
// The ids returned by initData are not used; the lookup uses get_ids.
List<Long> ids = Utils.initData(client, collectionName);
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, get_ids);
Assert.assertEquals(resEntities.size(), 0);
// Binary tests
// Inserts binary entities and checks that getEntityByID returns the inserted
// binary vectors. The Map-based lines generate Constants.nb vectors locally,
// insert them with int and float fields whose values are the row index, fetch
// the first get_length ids, and compare each returned vector with vectors.get(i).
// NOTE(review): a GetEntityByIDResponse-based variant and the Map-based
// variant appear side by side (looks like two revisions interleaved), and the
// InsertParam builder chain has no visible start call — confirm which lines
// are current.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testGetEntityByIdValidBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getEntityIds();
int get_length = 20;
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
// Scalar field values are simply the row index.
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
List<Long> ids = client.insert(insertParam);
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
assert res.getFieldsMap().get(0).get(Constants.binaryFieldName).equals(Constants.vectorsBinary.get(0).rewind());
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, ids.subList(0, get_length));
for (int i = 0; i < get_length; i++) {
assert (resEntities.get(ids.get(i)).get(Constants.binaryVectorFieldName).equals(vectors.get(i)));
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testGetEntityByIdAfterDeleteBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getEntityIds();
Response res_delete = client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
List<Long> ids = Utils.initBinaryData(client, collectionName);
client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
assert (res.getFieldsMap().size() == 0);
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, ids.subList(0, 1));
Assert.assertEquals(resEntities.size(), 0);
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testGetEntityIdNotExistedBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
GetEntityByIDResponse res = client.getEntityByID(collectionName, get_ids);
assert (res.getFieldsMap().size() == 0);
List<Long> ids = Utils.initBinaryData(client, collectionName);
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, get_ids);
Assert.assertEquals(resEntities.size(), 0);
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.logging.Logger;
import java.util.List;
public class TestIndex {
// case-01
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCreateIndex(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
Response statsResponse = client.getCollectionStats(collectionName);
// TODO: should check getCollectionStats
if(statsResponse.ok()) {
JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
Assert.assertEquals(file.get("index_type"), Constants.indexType));
List<Long> ids = Utils.initData(client, collectionName);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
String stats = client.getCollectionStats(collectionName);
JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
Assert.assertEquals(file.get("index_type"), Constants.indexType.toString()));
// case-02
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testCreateIndexBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
Index index = new Index.Builder(collectionName, Constants.binaryFieldName).withParamsInJson(Constants.binaryIndexParam).build();
Response res_create = client.createIndex(index);
Response statsResponse = client.getCollectionStats(collectionName);
// TODO: should check getCollectionStats
if(statsResponse.ok()) {
JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
Assert.assertEquals(file.get("index_type"), Constants.defaultBinaryIndexType));
List<Long> ids = Utils.initBinaryData(client, collectionName);
Index index = Index
.create(collectionName, Constants.binaryVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
String stats = client.getCollectionStats(collectionName);
JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
Assert.assertEquals(file.get("index_type"), Constants.defaultBinaryIndexType.toString()));
// case-03
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCreateIndexRepeatably(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
Response res_create_2 = client.createIndex(index);
List<Long> ids = Utils.initData(client, collectionName);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
// case-04
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCreateIndexWithNoVector(MilvusClient client, String collectionName) {
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-05
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreateIndexTableNotExisted(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
Index index = new Index.Builder(collectionNameNew, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
Index index = Index
.create(collectionNameNew, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-06
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testCreateIndexWithoutConnect(MilvusClient client, String collectionName) {
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-07
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreateIndexInvalidNList(MilvusClient client, String collectionName) {
int n_list = 0;
String indexParamNew = Utils.setIndexParam(Constants.indexType, "L2", n_list);
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(indexParamNew).build();
Response res_create = client.createIndex(index);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", n_list).build());
// # 3407
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// #3407
// case-08
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreateIndexInvalidMetricTypeBinary(MilvusClient client, String collectionName) {
String metric_type = "L2";
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
String indexParamNew = Utils.setIndexParam(Constants.defaultBinaryIndexType, metric_type, Constants.n_list);
Index createIndexParam = new Index.Builder(collectionName, Constants.binaryFieldName).withParamsInJson(indexParamNew).build();
Response res_create = client.createIndex(createIndexParam);
assert (!res_create.ok());
MetricType metric_type = MetricType.L2;
List<Long> ids = Utils.initBinaryData(client, collectionName);
Index index = Index
.create(collectionName, Constants.binaryVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
// #3408
// case-09
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDropIndex(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
Response res_drop = client.dropIndex(collectionName, Constants.floatFieldName);
Response statsResponse = client.getCollectionStats(collectionName);
// TODO: should check getCollectionStats
if(statsResponse.ok()) {
JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
filesJsonArray.stream().map(item -> (JSONObject) item).forEach(file->{
assert (!file.containsKey("index_type"));
List<Long> ids = Utils.initData(client, collectionName);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
client.dropIndex(collectionName, Constants.floatVectorFieldName);
String stats = client.getCollectionStats(collectionName);
JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
for (Object item : filesJsonArray) {
JSONObject file = (JSONObject) item;
// case-10
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testDropIndexBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
Index index = new Index.Builder(collectionName, Constants.binaryFieldName).withParamsInJson(Constants.binaryIndexParam).build();
Response res_create = client.createIndex(index);
Response res_drop = client.dropIndex(collectionName, Constants.binaryFieldName);
Response statsResponse = client.getCollectionStats(collectionName);
// TODO: should check getCollectionStats
if(statsResponse.ok()) {
JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
filesJsonArray.stream().map(item -> (JSONObject) item).forEach(file->{
assert (!file.containsKey("index_type"));
List<Long> ids = Utils.initBinaryData(client, collectionName);
Index index = Index
.create(collectionName, Constants.binaryVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
client.dropIndex(collectionName, Constants.binaryVectorFieldName);
String stats = client.getCollectionStats(collectionName);
JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
for (Object item : filesJsonArray) {
JSONObject file = (JSONObject) item;
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDropIndexTableNotExisted(MilvusClient client, String collectionName) {
// case-11
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropIndexCollectionNotExisted(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
Response res_drop = client.dropIndex(collectionNameNew, Constants.floatFieldName);
client.dropIndex(collectionNameNew, Constants.floatVectorFieldName);
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-12
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testDropIndexWithoutConnect(MilvusClient client, String collectionName) {
Response res_drop = client.dropIndex(collectionName, Constants.floatFieldName);
client.dropIndex(collectionName, Constants.floatVectorFieldName);
// // case-13
// @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// public void testAsyncIndex(MilvusClient client, String collectionName) {
// Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
// ListenableFuture<Response> createIndexResFuture = client.createIndexAsync(index);
// Futures.addCallback(
// createIndexResFuture, new FutureCallback<Response>() {
// @Override
// public void onSuccess(Response createIndexResponse) {
// Assert.assertNotNull(createIndexResponse);
// Assert.assertTrue(createIndexResponse.ok());
// Response statsResponse = client.getCollectionStats(collectionName);
// if(statsResponse.ok()) {
// JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
// filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
// Assert.assertEquals(file.get("index_type"), Constants.indexType));
// }
// }
// @Override
// public void onFailure(Throwable t) {
// System.out.println(t.getMessage());
// }
// }, MoreExecutors.directExecutor()
// );
// }
@ -3,7 +3,8 @@ package com;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -12,175 +13,254 @@ import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
public class TestInsertEntities {
int dimension = Constants.dimension;
String tag = "tag";
int nb = Constants.nb;
Map<String, List> entities = Constants.defaultEntities;
Map<String, List> binaryEntities = Constants.defaultBinaryEntities;
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-01
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertEntitiesCollectionNotExisted(MilvusClient client, String collectionName) {
String collectionNameNew = collectionName + "_";
InsertParam insertParam = new InsertParam.Builder(collectionNameNew)
InsertResponse res = client.insert(insertParam);
InsertParam insertParam = Utils.genInsertParam(collectionNameNew);
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// case-02
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
public void testInsertEntitiesWithoutConnect(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
InsertParam insertParam = Utils.genInsertParam(collectionName);
// case-03
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testInsertEntities(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
Response res_flush = client.flush(collectionName);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
List<Long> vectorIds = client.insert(insertParam);
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
Assert.assertEquals(vectorIds.size(), nb);
Assert.assertEquals(client.countEntities(collectionName), nb);
// case-04
@Test(dataProvider = "IdCollection", dataProviderClass = MainClass.class)
public void testInsertEntityWithIds(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> entityIds = LongStream.range(0, nb).boxed().collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(collectionName)
InsertResponse res = client.insert(insertParam);
Response res_flush = client.flush(collectionName);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
List<Long> vectorIds = client.insert(insertParam);
// Assert ids and collection row count
Assert.assertEquals(res.getEntityIds(), entityIds);
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
Assert.assertEquals(vectorIds, entityIds);
Assert.assertEquals(client.countEntities(collectionName), nb);
@Test(dataProvider = "IdCollection", dataProviderClass = MainClass.class)
// case-05
@Test(dataProvider = "IdCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertEntityWithInvalidIds(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> entityIds = LongStream.range(0, nb+1).boxed().collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).withEntityIds(entityIds).build();
InsertResponse res = client.insert(insertParam);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-06
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertEntityWithInvalidDimension(MilvusClient client, String collectionName) {
List<List<Float>> vectors = Utils.genVectors(nb, dimension+1, true);
List<Map<String,Object>> entities = Utils.genDefaultEntities(dimension+1,nb,vectors);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(entities).build();
InsertResponse res = client.insert(insertParam);
Map<String, List> entities = Utils.genDefaultEntities(nb,vectors);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-07
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertEntityWithInvalidVectors(MilvusClient client, String collectionName) {
List<Map<String,Object>> invalidEntities = Utils.genDefaultEntities(dimension,nb,new ArrayList<>());
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(invalidEntities).build();
InsertResponse res = client.insert(insertParam);
Map<String, List> entities = Utils.genDefaultEntities(nb,new ArrayList<>());
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
// ----------------------------- partition cases in Insert ---------------------------------
// Add vectors into collection with given tag
// case-08: Add vectors into collection with given tag
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testInsertEntityPartition(MilvusClient client, String collectionName) {
Response createpResponse = client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).withPartitionTag(tag).build();
InsertResponse res = client.insert(insertParam);
Response res_flush = client.flush(collectionName);
client.createPartition(collectionName, tag);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
// Assert collection row count
Response statsResponse = client.getCollectionStats(collectionName);
if(statsResponse.ok()) {
JSONArray partitionsJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "partitions");
partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
Assert.assertEquals(obj.get("row_count"), nb);
Assert.assertEquals(obj.get("tag"), tag);
String stats = client.getCollectionStats(collectionName);
JSONArray partitionsJsonArray = Utils.parseJsonArray(stats, "partitions");
partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
Assert.assertEquals(obj.get("row_count"), nb);
Assert.assertEquals(obj.get("tag"), tag);
// Add vectors into collection, which tag not existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// case-09: Add vectors into collection, which tag not existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertEntityPartitionTagNotExisted(MilvusClient client, String collectionName) {
Response createpResponse = client.createPartition(collectionName, tag);
String tag = RandomStringUtils.randomAlphabetic(10);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).withPartitionTag(tag).build();
InsertResponse res = client.insert(insertParam);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
// Binary tests
// case-10: Binary tests
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testInsertEntityPartitionABinary(MilvusClient client, String collectionName) {
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).withPartitionTag(tag).build();
InsertResponse res = client.insert(insertParam);
Response res_flush = client.flush(collectionName);
// Assert collection row count
Response statsResponse = client.getCollectionStats(collectionName);
if(statsResponse.ok()) {
JSONArray partitionsJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "partitions");
partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
Assert.assertEquals(obj.get("tag"), tag);
Assert.assertEquals(obj.get("row_count"), nb);
client.createPartition(collectionName, tag);
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors)
// Assert collection row count
String stats = client.getCollectionStats(collectionName);
JSONArray partitionsJsonArray = Utils.parseJsonArray(stats, "partitions");
partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
Assert.assertEquals(obj.get("tag"), tag);
Assert.assertEquals(obj.get("row_count"), nb);
// case-11
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testInsertEntityBinary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse res = client.insert(insertParam);
Response res_flush = client.flush(collectionName);
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
List<Long> vectorIds = client.insert(insertParam);
// Assert collection row count
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
Assert.assertEquals(vectorIds.size(), nb);
Assert.assertEquals(client.countEntities(collectionName), nb);
// case-12
@Test(dataProvider = "BinaryIdCollection", dataProviderClass = MainClass.class)
public void testInsertBinaryEntityWithIds(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> entityIds = LongStream.range(0, nb).boxed().collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).withEntityIds(entityIds).build();
InsertResponse res = client.insert(insertParam);
Response res_flush = client.flush(collectionName);
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors)
List<Long> vectorIds = client.insert(insertParam);
// Assert collection row count
Assert.assertEquals(entityIds, res.getEntityIds());
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
Assert.assertEquals(entityIds, vectorIds);
Assert.assertEquals(client.countEntities(collectionName), nb);
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// case-13
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertBinaryEntityWithInvalidIds(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> invalidEntityIds = LongStream.range(0, nb+1).boxed().collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).withEntityIds(invalidEntityIds).build();
InsertResponse res = client.insert(insertParam);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, binaryEntities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, binaryEntities.get(Constants.floatFieldName))
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, binaryEntities.get(Constants.binaryVectorFieldName))
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// case-14
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testInsertBinaryEntityWithInvalidDimension(MilvusClient client, String collectionName) {
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension-1);
List<Map<String,Object>> binaryEntities = Utils.genDefaultBinaryEntities(dimension-1,nb,vectorsBinary);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(binaryEntities).build();
InsertResponse res = client.insert(insertParam);
List<ByteBuffer> vectors = Utils.genBinaryVectors(nb, dimension-1);
Map<String, List> entities = Utils.genDefaultBinaryEntities(nb,vectors);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, entities.get(Constants.binaryVectorFieldName));
// @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
// public void testAsyncInsert(MilvusClient client, String collectionName) {
// InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
// ListenableFuture<InsertResponse> insertResFuture = client.insertAsync(insertParam);
// Futures.addCallback(
// insertResFuture, new FutureCallback<InsertResponse>() {
// @Override
// public void onSuccess(InsertResponse insertResponse) {
// Assert.assertNotNull(insertResponse);
// Assert.assertTrue(insertResponse.ok());
// Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
// }
// @Override
// public void onFailure(Throwable t) {
// System.out.println(t.getMessage());
// }
// }, MoreExecutors.directExecutor()
// );
// }
@ -1,4 +1,4 @@
package com;
package com;//package com;
//import io.milvus.client.*;
//import org.apache.commons.lang3.RandomStringUtils;
@ -1,4 +1,4 @@
package com;
package com;//package com;
//import io.milvus.client.*;
//import org.apache.commons.cli.*;
package com;
import io.milvus.client.*;
import io.milvus.client.exception.ServerSideMilvusException;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
public class TestPartition {
int dimension = 128;
// ----------------------------- create partition cases in ---------------------------------
@ -15,22 +15,19 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCreatePartition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
Assert.assertEquals(client.hasPartition(collectionName, tag), true);
// show partitions
List<String> partitions = client.listPartitions(collectionName).getPartitionList();
List<String> partitions = client.listPartitions(collectionName);
// create partition, tag name existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreatePartitionTagNameExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (!createpResponseNew.ok());
client.createPartition(collectionName, tag);
client.createPartition(collectionName, tag);
// ----------------------------- has partition cases in ---------------------------------
@ -38,23 +35,17 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testHasPartitionTagNameNotExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
String tagNew = RandomStringUtils.randomAlphabetic(10);
HasPartitionResponse haspResponse = client.hasPartition(collectionName, tagNew);
assert (haspResponse.ok());
Assert.assertFalse(client.hasPartition(collectionName, tagNew));
// has partition, tag name existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testHasPartitionTagNameExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
HasPartitionResponse haspResponse = client.hasPartition(collectionName, tag);
assert (haspResponse.ok());
client.createPartition(collectionName, tag);
Assert.assertTrue(client.hasPartition(collectionName, tag));
// ----------------------------- drop partition cases in ---------------------------------
@ -63,62 +54,47 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDropPartition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (createpResponseNew.ok());
Response response = client.dropPartition(collectionName, tag);
assert (response.ok());
client.createPartition(collectionName, tag);
client.dropPartition(collectionName, tag);
// show partitions
int length = client.listPartitions(collectionName).getPartitionList().size();
int length = client.listPartitions(collectionName).size();
// _default
Assert.assertEquals(length, 1);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionDefault(MilvusClient client, String collectionName) {
String tag = "_default";
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (!createpResponseNew.ok());
// show partitions
// System.out.println(client.listPartitions(collectionName).getPartitionList());
// int length = client.listPartitions(collectionName).getPartitionList().size();
// // _default
// Assert.assertEquals(length, 1);
client.createPartition(collectionName, tag);
// drop a partition repeat created before, drop by partition name
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionRepeat(MilvusClient client, String collectionName) throws InterruptedException {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
Response response = client.dropPartition(collectionName, tag);
assert (response.ok());
Response newResponse = client.dropPartition(collectionName, tag);
assert (!newResponse.ok());
client.createPartition(collectionName, tag);
client.dropPartition(collectionName, tag);
client.dropPartition(collectionName, tag);
// drop a partition not created before
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionNotExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
String tagNew = RandomStringUtils.randomAlphabetic(10);
Response response = client.dropPartition(collectionName, tagNew);
client.dropPartition(collectionName, tagNew);
// drop a partition not created before
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionTagNotExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
client.createPartition(collectionName, tag);
String newTag = RandomStringUtils.randomAlphabetic(10);
Response response = client.dropPartition(collectionName, newTag);
client.dropPartition(collectionName, newTag);
// ----------------------------- show partitions cases in ---------------------------------
@ -127,27 +103,22 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testShowPartitions(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
ListPartitionsResponse response = client.listPartitions(collectionName);
assert (response.getResponse().ok());
client.createPartition(collectionName, tag);
List<String> partitions = client.listPartitions(collectionName);
// create multi partition, then show partitions
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testShowPartitionsMulti(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
String tagNew = RandomStringUtils.randomAlphabetic(10);
Response newCreatepResponse = client.createPartition(collectionName, tagNew);
assert (newCreatepResponse.ok());
ListPartitionsResponse response = client.listPartitions(collectionName);
assert (response.getResponse().ok());
client.createPartition(collectionName, tagNew);
List<String> partitions = client.listPartitions(collectionName);
@ -1,25 +1,25 @@
import io.milvus.client.*;
import org.testng.annotations.Test;
public class TestPing {
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void testServerStatus(String host, int port) throws ConnectFailedException {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
Response res = client.getServerStatus();
assert (res.ok());
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void testServerStatusWithoutConnected(MilvusClient client, String collectionName) throws ConnectFailedException {
Response res = client.getServerStatus();
assert (!res.ok());
//package com1;
//import io.milvus.client.*;
//import org.testng.annotations.Test;
//public class TestPing {
// @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
// public void testServerStatus(String host, int port) throws ConnectFailedException {
// System.out.println("Host: "+host+", Port: "+port);
// MilvusClient client = new MilvusGrpcClient();
// ConnectParam connectParam = new ConnectParam.Builder()
// .withHost(host)
// .withPort(port)
// .build();
// client.connect(connectParam);
// Response res = client.getServerStatus();
// assert (res.ok());
// }
// @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// public void testServerStatusWithoutConnected(MilvusClient client, String collectionName) throws ConnectFailedException {
// Response res = client.getServerStatus();
// assert (!res.ok());
// }
//package com;
package com;//package com;
//import io.milvus.client.*;
//import org.apache.commons.lang3.RandomStringUtils;
package com;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import io.milvus.client.exception.InvalidDsl;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TestSearchEntities {
int small_nb = 10;
int n_probe = 20;
int top_k = 10;
int nq = 5;
int top_k = Constants.topk;
int nq = Constants.nq;
List<List<Float>> queryVectors = Constants.vectors.subList(0, nq);
List<ByteBuffer> queryVectorsBinary = Constants.vectorsBinary.subList(0, nq);
public String dsl = Constants.searchParam;
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public String floatDsl = Constants.searchParam;
public String binaryDsl = Constants.binarySearchParam;
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchCollectionNotExisted(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
assert (!res_search.getResponse().ok());
SearchParam searchParam = SearchParam.create(collectionNameNew).setDsl(floatDsl);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCollectionEmpty(MilvusClient client, String collectionName) {
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultIdsList().size(), 0);
// # 3429
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCollection(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultIdsList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultDistancesList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultIdsList().get(0).size(), Constants.topk);
@ -56,13 +57,11 @@ public class TestSearchEntities {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchDistance(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
SearchResult res_search = client.search(searchParam);
for (int i = 0; i < Constants.nq; i++) {
double distance = res_search.getResultDistancesList().get(i).get(0);
Assert.assertEquals(distance, 0.0, Constants.epsilon);
@ -71,14 +70,12 @@ public class TestSearchEntities {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchDistanceIP(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
String dsl = Utils.setSearchParam("IP", Constants.vectors.subList(0, nq), top_k, n_probe);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
String dsl = Utils.setSearchParam(MetricType.IP, queryVectors, top_k, n_probe);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
SearchResult res_search = client.search(searchParam);
for (int i = 0; i < Constants.nq; i++) {
double distance = res_search.getResultDistancesList().get(i).get(0);
Assert.assertEquals(distance, 1.0, Constants.epsilon);
@ -91,174 +88,359 @@ public class TestSearchEntities {
List<String> queryTags = new ArrayList<>();
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
InsertParam insertParam = Utils.genInsertParam(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build();
SearchResponse res_search = client.search(searchParam);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl).setPartitionTags(queryTags);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultDistancesList().size(), 0);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchPartitionNotExited(MilvusClient client, String collectionName) {
public void testSearchPartitionNotExisted(MilvusClient client, String collectionName) {
String tag = Utils.genUniqueStr("tag");
String tagNew = Utils.genUniqueStr("tagNew");
List<String> queryTags = new ArrayList<>();
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
assert (res.getResponse().ok());
Map<String, List> entities = Constants.defaultEntities;
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build();
SearchResponse res_search = client.search(searchParam);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl).setPartitionTags(queryTags);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultDistancesList().size(), 0);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchInvalidNProbe(MilvusClient client, String collectionName) {
int n_probe_new = -1;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
Response res_create = client.createIndex(index);
String dsl = Utils.setSearchParam(Constants.defaultMetricType, Constants.vectors.subList(0, nq), top_k, n_probe_new);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
Index index = Index
.create(collectionName, Constants.floatVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
String dsl = Utils.setSearchParam(Constants.defaultMetricType, queryVectors, top_k, n_probe_new);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCountLessThanTopK(MilvusClient client, String collectionName) {
int top_k_new = 100;
int nb = 50;
List<Map<String,Object>> entities = Utils.genDefaultEntities(Constants.dimension, nb, Utils.genVectors(nb, Constants.dimension, false));
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(entities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
Map<String,List> entities = Utils.genDefaultEntities(nb, Utils.genVectors(nb, Constants.dimension, false));
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
String dsl = Utils.setSearchParam(Constants.defaultMetricType, Constants.vectors.subList(0, nq), top_k_new, n_probe);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
String dsl = Utils.setSearchParam(Constants.defaultMetricType, queryVectors, top_k_new, n_probe);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultIdsList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultDistancesList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultIdsList().get(0).size(), nb);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchInvalidTopK(MilvusClient client, String collectionName) {
int top_k = -1;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
// Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
// Response res_create = client.createIndex(index);
String dsl = Utils.setSearchParam(Constants.defaultMetricType, Constants.vectors.subList(0, nq), top_k, n_probe);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
String dsl = Utils.setSearchParam(Constants.defaultMetricType, queryVectors, top_k, n_probe);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
// // Binary tests
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// public void testSearchCollectionNotExistedBinary(MilvusClient client, String collectionName) {
// String collectionNameNew = Utils.genUniqueStr(collectionName);
// SearchParam searchParam = new SearchParam.Builder(collectionNameNew)
// .withBinaryVectors(queryVectorsBinary)
// .withParamsInJson(searchParamStr)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (!res_search.getResponse().ok());
// }
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// public void test_search_index_IVFLAT_binary(MilvusClient client, String collectionName) {
// IndexType indexType = IndexType.IVFLAT;
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
// InsertResponse res = client.insert(insertParam);
// client.flush(collectionName);
// Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
// client.createIndex(index);
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withBinaryVectors(queryVectorsBinary)
// .withParamsInJson(searchParamStr)
// .withTopK(top_k).build();
// List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
// Assert.assertEquals(res_search.size(), nq);
// Assert.assertEquals(res_search.get(0).size(), top_k);
// }
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// public void test_search_ids_IVFLAT_binary(MilvusClient client, String collectionName) {
// IndexType indexType = IndexType.IVFLAT;
// List<Long> vectorIds;
// vectorIds = Stream.iterate(0L, n -> n)
// .limit(nb)
// .collect(Collectors.toList());
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).withVectorIds(vectorIds).build();
// InsertResponse res = client.insert(insertParam);
// Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
// client.createIndex(index);
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withBinaryVectors(queryVectorsBinary)
// .withParamsInJson(searchParamStr)
// .withTopK(top_k).build();
// List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
// Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L);
// }
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// public void test_search_partition_not_existed_binary(MilvusClient client, String collectionName) {
// IndexType indexType = IndexType.IVFLAT;
// String tag = RandomStringUtils.randomAlphabetic(10);
// Response createpResponse = client.createPartition(collectionName, tag);
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
// InsertResponse res = client.insert(insertParam);
// String tagNew = RandomStringUtils.randomAlphabetic(10);
// List<String> queryTags = new ArrayList<>();
// queryTags.add(tagNew);
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withBinaryVectors(queryVectorsBinary)
// .withParamsInJson(searchParamStr)
// .withPartitionTags(queryTags)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (!res_search.getResponse().ok());
// Assert.assertEquals(res_search.getQueryResultsList().size(), 0);
// }
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// public void test_search_invalid_n_probe_binary(MilvusClient client, String collectionName) {
// int n_probe_new = 0;
// String searchParamStrNew = Utils.setSearchParam(n_probe_new);
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
// client.insert(insertParam);
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withBinaryVectors(queryVectorsBinary)
// .withParamsInJson(searchParamStrNew)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (res_search.getResponse().ok());
// }
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
// public void test_search_invalid_top_k_binary(MilvusClient client, String collectionName) {
// int top_k_new = 0;
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
// client.insert(insertParam);
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withBinaryVectors(queryVectorsBinary)
// .withParamsInJson(searchParamStr)
// .withTopK(top_k_new).build();
// SearchResponse res_search = client.search(searchParam);
// assert (!res_search.getResponse().ok());
// }
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchMultiMust(MilvusClient client, String collectionName){
JSONObject vectorParam = Utils.genVectorParam(Constants.defaultMetricType, Constants.vectors.subList(0,nq), top_k, n_probe);
JSONObject boolParam = new JSONObject();
JSONObject mustParam = new JSONObject();
JSONArray jsonArray = new JSONArray();
JSONObject mustParam1 = new JSONObject();
mustParam1.put("must", jsonArray);
JSONArray jsonArray1 = new JSONArray();
mustParam.put("must", jsonArray1);
boolParam.put("bool", mustParam);
InsertParam insertParam = Utils.genInsertParam(collectionName);
String dsl = boolParam.toJSONString();
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
SearchResult resSearch = client.search(searchParam);
Assert.assertEquals(resSearch.getResultIdsList().size(), Constants.nq);
Assert.assertEquals(resSearch.getResultDistancesList().size(), Constants.nq);
Assert.assertEquals(resSearch.getResultIdsList().get(0).size(), Constants.topk);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = InvalidDsl.class)
public void testSearchMultiVectors(MilvusClient client, String collectionName) {
String dsl = String.format(
"{\"bool\": {"
+ "\"must\": [{"
+ " \"range\": {"
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
+ " }},{"
+ " \"vector\": {"
+ " \"float_vector\": {"
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
+ " }}},{"
+ " \"vector\": {"
+ " \"float_vector\": {"
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}\"\n"
+ " }}}]}}",
top_k, queryVectors, top_k, queryVectors);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = InvalidDsl.class)
public void testSearchNoVectors(MilvusClient client, String collectionName) {
String dsl = String.format(
"{\"bool\": {"
+ "\"must\": [{"
+ " \"range\": {"
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
+ " }},{"
+ " \"vector\": {"
+ " \"float_vector\": {"
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"params\": {\"nprobe\": 20}"
+ " }}}]}}",
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchVectorNotExisted(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
String dsl = String.format(
"{\"bool\": {"
+ "\"must\": [{"
+ " \"range\": {"
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
+ " }},{"
+ " \"vector\": {"
+ " \"float_vector\": {"
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
+ " }}}]}}",
top_k, new ArrayList<>());
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
SearchResult resSearch = client.search(searchParam);
Assert.assertEquals(resSearch.getResultIdsList().size(), 0);
Assert.assertEquals(resSearch.getResultDistancesList().size(), 0);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchVectorDifferentDim(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<List<Float>> query = Utils.genVectors(nq,64, false);
String dsl = String.format(
"{\"bool\": {"
+ "\"must\": [{"
+ " \"range\": {"
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
+ " }},{"
+ " \"vector\": {"
+ " \"float_vector\": {"
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
+ " }}}]}}",
top_k, query);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testAsyncSearch(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
ListenableFuture<SearchResult> searchResFuture = client.searchAsync(searchParam);
searchResFuture, new FutureCallback<SearchResult>() {
public void onSuccess(SearchResult searchResult) {
Assert.assertEquals(searchResult.getResultIdsList().size(), Constants.nq);
Assert.assertEquals(searchResult.getResultDistancesList().size(), Constants.nq);
Assert.assertEquals(searchResult.getResultIdsList().get(0).size(), Constants.topk);
Assert.assertEquals(searchResult.getFieldsMap().get(0).size(), top_k);
public void onFailure(Throwable t) {
}, MoreExecutors.directExecutor()
// Binary tests
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchCollectionNotExistedBinary(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
SearchParam searchParam = SearchParam.create(collectionNameNew).setDsl(binaryDsl);
// Binary-vector search: prepares Constants.nb rows (int64, float and binary-vector fields),
// searches with the first Constants.nq generated vectors using the JACCARD metric, and checks
// that nq id lists and nq distance lists come back with top_k ids in the first list.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testSearchCollectionBinary(MilvusClient client, String collectionName) {
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
// Each query vector is rendered as the text form of its backing byte array.
List<String> vectorsToSearch = vectors.subList(0, Constants.nq)
.stream().map(byteBuffer -> Arrays.toString(byteBuffer.array()))
// %d receives top_k; %s receives the textual list of query vectors.
String dsl = String.format(
"{\"bool\": {"
+ "\"must\": [{"
+ " \"vector\": {"
+ " \"binary_vector\": {"
+ " \"topk\": %d, \"metric_type\": \"JACCARD\", \"type\": \"binary\", \"query\": %s, \"params\": {\"nprobe\": 20}"
+ " }}}]}}",
top_k, vectorsToSearch.toString());
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
SearchResult resSearch = client.search(searchParam);
Assert.assertEquals(resSearch.getResultIdsList().size(), nq);
Assert.assertEquals(resSearch.getResultDistancesList().size(), nq);
Assert.assertEquals(resSearch.getResultIdsList().get(0).size(), top_k);
// Binary-vector search with an index definition: inserts Constants.nb rows, prepares an Index
// on the binary vector field with the "nlist" parameter, then searches with the first
// Constants.nq inserted vectors. The first hit of the first query must be the first inserted id.
// NOTE(review): the index type and the call that submits the index are not visible in these lines.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testSearchIVFLATBinary(MilvusClient client, String collectionName) {
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
List<Long> entityIds = client.insert(insertParam);
Index index = Index
.create(collectionName, Constants.binaryVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
SearchParam searchParam = SearchParam
.setDsl(Utils.setBinarySearchParam(Constants.defaultBinaryMetricType, vectors.subList(0, Constants.nq), Constants.topk, n_probe));
SearchResult resSearch = client.search(searchParam);
Assert.assertEquals(resSearch.getResultIdsList().size(), nq);
Assert.assertEquals(resSearch.getResultIdsList().get(0).size(), top_k);
Assert.assertEquals(resSearch.getResultIdsList().get(0).get(0), entityIds.get(0));
// Creates a partition named by `tag`, prepares Constants.nb binary rows, then builds a search
// over binaryDsl that is restricted to the partition tags in queryTags.
// NOTE(review): tagNew is generated but never added to queryTags in the lines shown, so an
// empty tag list is what reaches setPartitionTags - confirm whether an add call is missing.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testSearchPartitionNotExistedBinary(MilvusClient client, String collectionName) {
String tag = Utils.genUniqueStr("tag");
client.createPartition(collectionName, tag);
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
String tagNew = Utils.genUniqueStr("tagNew");
List<String> queryTags = new ArrayList<>();
SearchParam searchParam = SearchParam.create(collectionName).setDsl(binaryDsl).setPartitionTags(queryTags);
// #3656
// Negative test (ServerSideMilvusException expected): prepares binary rows and an index
// definition with "nlist", then builds a search DSL whose nprobe is 0 (n_probe_new).
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchInvalidNProbeBinary(MilvusClient client, String collectionName) {
int n_probe_new = 0;
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
Index index = Index
.create(collectionName, Constants.binaryVectorFieldName)
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
String dsl = Utils.setBinarySearchParam(Constants.defaultBinaryMetricType, vectors.subList(0, Constants.nq), top_k, n_probe_new);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
// Negative test (ServerSideMilvusException expected): builds a binary search DSL with topk = -1.
// The local top_k declared below hides the outer top_k that the other tests reference.
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchInvalidTopKBinary(MilvusClient client, String collectionName) {
int top_k = -1;
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
String dsl = Utils.setBinarySearchParam(Constants.defaultBinaryMetricType, vectors.subList(0, Constants.nq), top_k, n_probe);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
package com;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import io.milvus.client.*;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomStringUtils;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.testng.Assert;
public class Utils {
@ -27,14 +27,14 @@ public class Utils {
public static List<List<Float>> genVectors(int vectorCount, int dimension, boolean norm) {
List<List<Float>> vectors = new ArrayList<>();
Random random = new Random();
List<List<Float>> vectors = new ArrayList<>();
for (int i = 0; i < vectorCount; ++i) {
List<Float> vector = new ArrayList<>();
for (int j = 0; j < dimension; ++j) {
if (norm == true) {
if (norm) {
vector = normalize(vector);
@ -42,12 +42,12 @@ public class Utils {
return vectors;
static List<ByteBuffer> genBinaryVectors(long vectorCount, long dimension) {
static List<ByteBuffer> genBinaryVectors(int vectorCount, int dimension) {
Random random = new Random();
List<ByteBuffer> vectors = new ArrayList<>();
final long dimensionInByte = dimension / 8;
for (long i = 0; i < vectorCount; ++i) {
ByteBuffer byteBuffer = ByteBuffer.allocate((int) dimensionInByte);
List<ByteBuffer> vectors = new ArrayList<>(vectorCount);
final int dimensionInByte = dimension / 8;
for (int i = 0; i < vectorCount; ++i) {
ByteBuffer byteBuffer = ByteBuffer.allocate(dimensionInByte);
@ -57,10 +57,10 @@ public class Utils {
private static List<Map<String, Object>> genBaseFieldsWithoutVector(){
List<Map<String,Object>> fieldsList = new ArrayList<>();
Map<String, Object> intFields = new HashMap<>();
Map<String, Object> floatField = new HashMap<>();
@ -72,10 +72,10 @@ public class Utils {
List<Map<String, Object>> defaultFieldList = genBaseFieldsWithoutVector();
Map<String, Object> vectorField = new HashMap<>();
if (isBinary){
vectorField.put(Constants.fieldNameKey, Constants.binaryVectorFieldName);
}else {
vectorField.put(Constants.fieldNameKey, Constants.floatVectorFieldName);
JSONObject jsonObject = new JSONObject();
@ -86,55 +86,33 @@ public class Utils {
return defaultFieldList;
public static List<Map<String,Object>> genDefaultEntities(int dimension, int vectorCount, List<List<Float>> vectors){
List<Map<String,Object>> fieldsMap = genDefaultFields(dimension, false);
public static Map<String, List> genDefaultEntities(int vectorCount, List<List<Float>> vectors){
// Map<String,Object> fieldsMap = genDefaultFields(dimension, false);
Map<String, List> fieldsMap =new HashMap<>();
List<Long> intValues = new ArrayList<>(vectorCount);
List<Float> floatValues = new ArrayList<>(vectorCount);
for (int i = 0; i < vectorCount; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
for(Map<String,Object> field: fieldsMap){
String fieldType = field.get("field").toString();
switch (fieldType){
case "int64":
case "float":
case "float_vector":
return fieldsMap;
public static List<Map<String,Object>> genDefaultBinaryEntities(int dimension, int vectorCount, List<ByteBuffer> vectorsBinary){
List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
public static Map<String, List> genDefaultBinaryEntities(int vectorCount, List<ByteBuffer> vectorsBinary){
// List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
Map<String, List> binaryFieldsMap =new HashMap<>();
List<Long> intValues = new ArrayList<>(vectorCount);
List<Float> floatValues = new ArrayList<>(vectorCount);
// List<List<Float>> vectors = genVectors(vectorCount,dimension,false);
// List<ByteBuffer> binaryVectors = genBinaryVectors(vectorCount,dimension);
for (int i = 0; i < vectorCount; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
for(Map<String,Object> field: binaryFieldsMap){
String fieldType = field.get("field").toString();
switch (fieldType){
case "int64":
case "float":
case "binary_vector":
return binaryFieldsMap;
@ -147,7 +125,7 @@ public class Utils {
return indexParams;
public static String setSearchParam(String metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
static JSONObject genVectorParam(MetricType metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
JSONObject searchParam = new JSONObject();
JSONObject fieldParam = new JSONObject();
fieldParam.put("topk", topk);
@ -158,34 +136,53 @@ public class Utils {
tmpSearchParam.put("nprobe", nprobe);
fieldParam.put("params", tmpSearchParam);
JSONObject vectorParams = new JSONObject();
vectorParams.put(Constants.floatFieldName, fieldParam);
vectorParams.put(Constants.floatVectorFieldName, fieldParam);
searchParam.put("vector", vectorParams);
JSONObject param = new JSONObject();
JSONObject mustParam = new JSONObject();
JSONArray tmp = new JSONArray();
mustParam.put("must", tmp);
param.put("bool", mustParam);
return JSONObject.toJSONString(param);
return searchParam;
public static String setBinarySearchParam(String metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
static JSONObject genBinaryVectorParam(MetricType metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
JSONObject searchParam = new JSONObject();
JSONObject fieldParam = new JSONObject();
fieldParam.put("topk", topk);
fieldParam.put("metricType", metricType);
fieldParam.put("queryVectors", queryVectors);
fieldParam.put("metric_type", metricType);
List<List<Byte>> vectorsToSearch = new ArrayList<>();
for (ByteBuffer byteBuffer : queryVectors) {
byte[] b = new byte[byteBuffer.remaining()];
fieldParam.put("query", vectorsToSearch);
fieldParam.put("type", Constants.binaryVectorType);
JSONObject tmpSearchParam = new JSONObject();
tmpSearchParam.put("nprobe", nprobe);
fieldParam.put("params", tmpSearchParam);
JSONObject vectorParams = new JSONObject();
vectorParams.put(Constants.floatFieldName, fieldParam);
vectorParams.put(Constants.binaryVectorFieldName, fieldParam);
searchParam.put("vector", vectorParams);
return searchParam;
public static String setSearchParam(MetricType metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
JSONObject searchParam = genVectorParam(metricType, queryVectors, topk, nprobe);
JSONObject boolParam = new JSONObject();
JSONObject mustParam = new JSONObject();
mustParam.put("must", new JSONArray().add(searchParam));
JSONArray tmp = new JSONArray();
mustParam.put("must", tmp);
boolParam.put("bool", mustParam);
return JSONObject.toJSONString(searchParam);
return JSONObject.toJSONString(boolParam);
// Builds the search DSL for a binary-vector query and returns it as JSON text shaped like
// {"bool": {"must": [...]}}. The vector clause comes from genBinaryVectorParam.
// NOTE(review): in the lines shown searchParam is never added to tmp, so "must" would
// serialize as an empty array - confirm that the add call exists in the real file.
public static String setBinarySearchParam(MetricType metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
JSONObject searchParam = genBinaryVectorParam(metricType, queryVectors, topk, nprobe);
JSONObject boolParam = new JSONObject();
JSONObject mustParam = new JSONObject();
JSONArray tmp = new JSONArray();
mustParam.put("must", tmp);
boolParam.put("bool", mustParam);
return JSONObject.toJSONString(boolParam);
public static int getIndexParamValue(String indexParam, String key) {
@ -218,7 +215,7 @@ public class Utils {
public static List<Float> getVector(List<Map<String,Object>> entities, int i){
List<Float> vector = new ArrayList<>();
entities.forEach(entity -> {
if("float_vector".equals(entity.get("field")) && Objects.nonNull(entity.get("values"))){
if(Constants.floatVectorFieldName.equals(entity.get("field")) && Objects.nonNull(entity.get("values"))){
@ -239,4 +236,237 @@ public class Utils {
throw new RuntimeException("unsupported type");
// Builds an InsertParam from Constants.defaultEntities: the int64, float and float-vector
// columns are looked up in that map by their field names.
// NOTE(review): collectionName is not referenced in the lines shown, and the map uses the
// raw List type, so element types are unchecked - confirm against the full file.
public static InsertParam genInsertParam(String collectionName) {
Map<String, List> entities = Constants.defaultEntities;
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
return insertParam;
// Builds an InsertParam with Constants.nb rows: int64 values 0..nb-1, the same values as
// floats, and binary vectors from Utils.genBinaryVectors(Constants.nb, Constants.dimension).
// NOTE(review): collectionName is not referenced in the lines shown - confirm.
public static InsertParam genBinaryInsertParam(String collectionName) {
List<Long> intValues = new ArrayList<>(Constants.nb);
List<Float> floatValues = new ArrayList<>(Constants.nb);
for (int i = 0; i < Constants.nb; ++i) {
intValues.add((long) i);
floatValues.add((float) i);
InsertParam insertParam = InsertParam
.addField(Constants.intFieldName, DataType.INT64, intValues)
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, Utils.genBinaryVectors(Constants.nb, Constants.dimension));
return insertParam;
// Creates a CollectionMapping for collectionName with an int64 and a float scalar field,
// the "segment_row_limit" and "auto_id" params, and one vector field of Constants.dimension:
// binary when isBinary is true, float otherwise.
// NOTE(review): isBinary is a boxed Boolean; passing null would fail on unboxing in the if.
public static CollectionMapping genCreateCollectionMapping(String collectionName, Boolean autoId, Boolean isBinary) {
CollectionMapping cm = CollectionMapping.create(collectionName)
.addField(Constants.intFieldName, DataType.INT64)
.addField(Constants.floatFieldName, DataType.FLOAT)
.setParamsInJson(new JsonBuilder()
.param("segment_row_limit", Constants.segmentRowLimit)
.param("auto_id", autoId)
if (isBinary) {
cm.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, Constants.dimension);
} else {
cm.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, Constants.dimension);
return cm;
// Inserts the default float entities (Utils.genInsertParam) into collectionName, asserts that
// the collection then reports Constants.nb entities, and returns the ids from the insert.
public static List<Long> initData(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Assert.assertEquals(client.countEntities(collectionName), Constants.nb);
return ids;
// Binary counterpart of initData: inserts the entities from Utils.genBinaryInsertParam,
// asserts that the collection then reports Constants.nb entities, and returns the inserted ids.
public static List<Long> initBinaryData(MilvusClient client, String collectionName) {
InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
Assert.assertEquals(client.countEntities(collectionName), Constants.nb);
return ids;
// public static Index genDefaultIndex(String collectionName, String fieldName, String indexType, String metricType, int nlist) {
// return new Index.Builder(collectionName, fieldName)
// .withParamsInJson(new JsonBuilder()
// .param("index_type", indexType)
// .param("metric_type", metricType)
// .indexParam("nlist", nlist)
// .build())
// .build();
// }
@ -0,0 +1,14 @@
@ -0,0 +1,13 @@
@ -0,0 +1,15 @@
# Image for running the test suite: copies the sources to /app, installs the build tools
# and the Python requirements, then hands control to /app/docker-entrypoint.sh.
FROM python:3.6.8-jessie
LABEL Name=megasearch_engine_test Version=0.0.1
COPY . /app
# requirements.txt is referenced by absolute path: it was copied to /app above and no
# WORKDIR is set here, so a relative path would be resolved against "/".
# NOTE(review): "apt-get remove --purge -y" names no packages, so it removes nothing;
# list the build-only packages there if the intent is to slim the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libc-dev build-essential && \
    python3 -m pip install --no-cache-dir -r /app/requirements.txt && \
    apt-get remove --purge -y && \
    rm -rf /var/lib/apt/lists/*
ENTRYPOINT [ "/app/docker-entrypoint.sh" ]
CMD [ "start" ]
## Requirements
* python 3.6.8+
* pip install -r requirements.txt
## How to Build Test Env
sudo docker pull registry.zilliz.com/milvus/milvus-test:v0.2
sudo docker run -it -v /home/zilliz:/home/zilliz -d registry.zilliz.com/milvus/milvus-test:v0.2
## How to Create the Test Env in Kubernetes (k8s)
# 1. start milvus k8s pod
cd milvus-helm/charts/milvus
helm install --wait --timeout 300s \
--set image.repository=registry.zilliz.com/milvus/engine \
--set persistence.enabled=true \
--set image.tag=PR-3818-gpu-centos7-release \
--set image.pullPolicy=Always \
--set service.type=LoadBalancer \
-f ci/db_backend/mysql_gpu_values.yaml \
-f ci/filebeat/values.yaml \
-f test.yaml \
--namespace milvus \
milvus-ci-pr-3818-1-single-centos7-gpu .
# 2. remove milvus k8s pod
helm uninstall -n milvus milvus-test
# 3. check k8s pod status
kubectl get svc -n milvus -w milvus-test
# 4. login to pod
kubectl get pods --namespace milvus
kubectl exec -it milvus-test-writable-6cc49cfcd4-rbrns -n milvus bash
## How to Run Test cases
# Test level-1 cases
pytest . --level=1 --ip=<milvus-server-ip> --port=19530
# Test level-1 cases in 'test_connect.py' only
pytest test_connect.py --level=1
## How to list test cases
# List all cases
pytest --dry-run -qq
# Collect all cases with docstring
pytest --collect-only -qq
# Create test report with allure
pytest --alluredir=test_out . -q -v
allure serve test_out
## Contribution getting started
* Follow PEP-8 for naming and black for formatting.
import logging
import logging
import pdb
import json
import requests
import traceback
import utils
from milvus import Milvus
from utils import *
url_collections = "collections"
url_system = "system/"
class Request(object):
    """Thin wrapper around ``requests`` bound to a single URL.

    Every verb helper returns a ``(ok, payload)`` tuple, where ``ok`` is the
    verdict of :meth:`_check_status` for the HTTP response.
    """

    def __init__(self, url):
        # logging.getLogger().error(url)
        self._url = url

    def _check_status(self, result):
        """Return True when ``result`` represents a successful call.

        A response is successful when its status code is 200, 201 or 204 and
        its body is empty, carries no ``"code"`` entry, or has ``"code" == 0``.
        The body is parsed once; a JSON body that is not an object has no
        ``"code"`` entry and therefore counts as successful.

        Raises whatever ``json.loads`` raises for a non-empty, non-JSON body.
        """
        # logging.getLogger().info(result.text)
        if result.status_code not in [200, 201, 204]:
            return False
        if not result.text:
            return True
        body = json.loads(result.text)
        if not isinstance(body, dict) or "code" not in body:
            return True
        return body["code"] == 0

    def get(self, data=None):
        """GET with ``data`` as query parameters; payload is the body's ``"data"`` entry."""
        res_get = requests.get(self._url, params=data)
        return self._check_status(res_get), json.loads(res_get.text)["data"]

    def get_with_body(self, data=None):
        """GET with ``data`` sent as a JSON body; payload is the body's ``"data"`` entry."""
        res_get = requests.get(self._url, data=json.dumps(data))
        return self._check_status(res_get), json.loads(res_get.text)["data"]

    def post(self, data):
        """POST ``data`` as JSON; payload is the decoded body, or the raw response if the body is empty."""
        res_post = requests.post(self._url, data=json.dumps(data))
        if res_post.text:
            return self._check_status(res_post), json.loads(res_post.text)
        return self._check_status(res_post), res_post

    def delete(self, data=None):
        """DELETE, with ``data`` as a JSON body when given; payload is the raw response."""
        # Exactly one request is sent: with a body when data is given, without one otherwise.
        if data:
            res_delete = requests.delete(self._url, data=json.dumps(data))
        else:
            res_delete = requests.delete(self._url)
        return self._check_status(res_delete), res_delete

    def put(self, data=None):
        """PUT, with ``data`` as a JSON body when given; payload is the raw response."""
        # Exactly one request is sent: with a body when data is given, without one otherwise.
        if data:
            res_put = requests.put(self._url, data=json.dumps(data))
        else:
            res_put = requests.put(self._url)
        return self._check_status(res_put), res_put
class MilvusClient(object):
    """Client for a Milvus server's HTTP endpoints, built on ``Request``.

    ``url`` is the base address; endpoint paths are appended to it verbatim,
    so it is presumably expected to end with ``/``. Unless noted otherwise, a
    method returns the decoded payload (or one entry of it) on success and
    ``False`` when the server reports a failure or the request raises.
    """

    def __init__(self, url):
        self._url = url

    @staticmethod
    def _payload_or_false(call, key=None):
        """Run ``call`` and map its ``(ok, payload)`` result to payload-or-False.

        ``call`` is a zero-argument callable returning the tuple produced by a
        ``Request`` helper. With ``key`` set, ``payload[key]`` is returned
        instead of the whole payload. Any exception is logged and mapped to
        ``False`` instead of being swallowed silently.
        """
        try:
            status, data = call()
            if status:
                return data if key is None else data[key]
            return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def create_collection(self, collection_name, fields):
        """Create a collection; ``fields`` is the request body.

        Note: ``fields`` is updated in place with the ``collection_name`` key.
        """
        url = self._url+url_collections
        r = Request(url)
        fields.update({"collection_name": collection_name})
        return self._payload_or_false(lambda: r.post(fields))

    def list_collections(self, offset=0, page_size=10):
        """Return the ``"collections"`` entry of one page of the collection list."""
        url = self._url+url_collections+'?'+'offset='+str(offset)+'&page_size='+str(page_size)
        r = Request(url)
        return self._payload_or_false(r.get, key="collections")

    def has_collection(self, collection_name):
        """Return the collection's payload when the GET succeeds, else False."""
        url = self._url+url_collections+'/'+collection_name
        r = Request(url)
        return self._payload_or_false(r.get)

    def drop_collection(self, collection_name):
        """Delete a collection; returns the raw response on success."""
        url = self._url+url_collections+'/'+str(collection_name)
        r = Request(url)
        return self._payload_or_false(r.delete)

    def info_collection(self, collection_name):
        """Return the collection's info payload."""
        url = self._url+url_collections+'/'+str(collection_name)
        r = Request(url)
        return self._payload_or_false(r.get)

    def stat_collection(self, collection_name):
        """Return the collection's statistics (GET with ``info=stat``)."""
        url = self._url+url_collections+'/'+str(collection_name)
        r = Request(url)
        return self._payload_or_false(lambda: r.get(data={"info": "stat"}))

    def count_collection(self, collection_name):
        """Return ``row_count`` from the collection statistics, or False if they are unavailable."""
        stats = self.stat_collection(collection_name)
        if not stats:
            # stat_collection signals failure with False, which cannot be subscripted.
            return False
        return stats["row_count"]

    def create_partition(self, collection_name, tag):
        """Create partition ``tag`` in the collection."""
        url = self._url+url_collections+'/'+collection_name+'/partitions'
        r = Request(url)
        create_params = {"partition_tag": tag}
        return self._payload_or_false(lambda: r.post(create_params))

    def list_partitions(self, collection_name):
        """Return the ``"partitions"`` entry of the collection's partition list."""
        url = self._url+url_collections+'/'+collection_name+'/partitions'
        r = Request(url)
        return self._payload_or_false(r.get, key="partitions")

    def drop_partition(self, collection_name, tag):
        """Delete partition ``tag``; returns the raw response on success."""
        url = self._url+url_collections+'/'+collection_name+'/partitions/'+tag
        r = Request(url)
        return self._payload_or_false(r.delete)

    def flush(self, collection_names):
        """Submit a flush task for ``collection_names``; returns the raw response on success."""
        # url_system already ends with '/', so append 'task' without another slash.
        url = self._url+url_system+'task'
        r = Request(url)
        flush_params = {
            "flush": {"collection_names": collection_names}}
        return self._payload_or_false(lambda: r.put(data=flush_params))

    def insert(self, collection_name, entities, tag=None):
        """Insert ``entities`` (optionally into partition ``tag``) and return the new ids."""
        url = self._url+url_collections+'/'+collection_name+'/entities'
        r = Request(url)
        insert_params = {"entities": entities}
        if tag:
            insert_params.update({"partition_tag": tag})
        return self._payload_or_false(lambda: r.post(insert_params), key="ids")

    def delete(self, collection_name, ids):
        """Delete the entities with the given ids; returns the raw response on success."""
        url = self._url+url_collections+'/'+collection_name+'/entities'
        r = Request(url)
        delete_params = {"ids": ids}
        return self._payload_or_false(lambda: r.delete(data=delete_params))

    def get_entities(self, collection_name, ids):
        """
        method: get entities by ids
        """
        ids = ','.join(str(i) for i in ids)
        url = self._url+url_collections+'/'+collection_name+'/entities?ids='+ids
        # url = self._url+url_collections+'/'+collection_name+'/entities'
        r = Request(url)
        return self._payload_or_false(r.get, key="entities")

    def create_index(self, collection_name, field_name, index_params):
        """
        method: create index
        """
        url = self._url+url_collections+'/'+collection_name+'/fields/'+field_name+'/indexes'
        r = Request(url)
        return self._payload_or_false(lambda: r.post(index_params))

    def drop_index(self, collection_name, field_name):
        """Delete the index on ``field_name``; returns the raw response on success."""
        url = self._url+url_collections+'/'+collection_name+'/fields/'+field_name+'/indexes'
        r = Request(url)
        return self._payload_or_false(r.delete)

    def describe_index(self, collection_name, field_name):
        """Return ``index_params`` of ``field_name`` from the collection info.

        Returns False when the info is unavailable and None when no field matches.
        """
        info = self.info_collection(collection_name)
        if not info:
            return False
        for field in info["fields"]:
            # Match on "name"; "field_name" is still accepted as a fallback.
            if field.get("name", field.get("field_name")) == field_name:
                return field["index_params"]
        return None

    def search(self, collection_name, query_expr, fields=None):
        """Search the collection with ``query_expr``, sent as the body of a GET request."""
        url = self._url+url_collections+'/'+str(collection_name)+'/entities'
        r = Request(url)
        search_params = {
            "query": query_expr,
            "fields": fields
        }
        return self._payload_or_false(lambda: r.get_with_body(search_params))

    def clear_db(self):
        """
        method: drop all collections in db
        """
        collections = self.list_collections(page_size=10000)
        if collections:
            for item in collections:
                # NOTE(review): the loop body was missing; reconstructed from the stated
                # purpose above - confirm that items carry the name under "collection_name".
                self.drop_collection(item["collection_name"])

    def system_cmd(self, cmd):
        """Run system command ``cmd`` and return the ``"reply"`` entry of its payload."""
        url = self._url+url_system+cmd
        r = Request(url)
        # Unpack the (ok, payload) tuple first; "reply" is read from the payload, not the tuple.
        return self._payload_or_false(r.get, key="reply")
import pdb
import pdb
import copy
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
import sklearn.preprocessing
import pytest
from utils import *
from constants import *
uid = "create_collection"
# Test cases for `client.create_collection`. Most cases take a unique name from
# gen_unique_str(uid) and send a deep copy of default_fields as the request body.
class TestCreateCollection:
The following cases are used to test `create_collection` function
# NOTE(review): the three get_* generators below only yield request.param; their fixture
# decorators and parameter lists are not visible here.
def get_filter_field(self, request):
yield request.param
def get_vector_field(self, request):
yield request.param
def get_segment_row_limit(self, request):
yield request.param
# Creates a collection with the parametrized segment_row_limit and checks that it exists.
def test_create_collection_segment_row_limit(self, client, get_segment_row_limit):
target: test create normal collection with different fields
method: create collection with diff segment_row_limit
expected: no exception raised
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_limit
assert client.create_collection(collection_name, fields)
assert client.has_collection(collection_name)
# Uses segment_row_limit = 10000000 and asserts that the collection does NOT exist afterwards.
# NOTE(review): the "expected" text below says no exception is raised, but the assertion
# checks that creation was rejected - the description looks copied from the previous case.
def test_create_collection_exceed_segment_row_limit(self, client):
target: test create normal collection with different fields
method: create collection with diff segment_row_limit
expected: no exception raised
segment_row_limit = 10000000
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = segment_row_limit
client.create_collection(collection_name, fields)
assert not client.has_collection(collection_name)
# Creates a collection with auto_id set to False and checks that it exists.
def test_create_collection_id(self, client):
target: test create id collection
method: create collection with auto_id false
expected: no exception raised
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["auto_id"] = False
# fields = gen_default_fields(auto_id=False)
client.create_collection(collection_name, fields)
assert client.has_collection(collection_name)
# The leading underscore keeps this case out of pytest's default `test_*` collection.
# NOTE(review): it uses a hard-coded collection name and the type 'BINARY_FLOAT' -
# confirm the intended binary vector type name before enabling it.
def _test_create_binary_collection(self, client):
collection_name = 'test_NRHgct0s'
fields = {'fields': [{'name': 'int64', 'type': 'INT64'},
{'name': 'float', 'type': 'FLOAT'},
{'name': 'binary_vector', 'type': 'BINARY_FLOAT', 'params': {'dim': 128}}],
'segment_row_limit': 1000, 'auto_id': True}
client.create_collection(collection_name, fields)
assert client.has_collection(collection_name)
import pytest
import pytest
import time
from utils import *
from constants import *
# Test cases for `client.drop_collection`: an existing collection, a name that was never
# created, and a set of invalid names.
class TestDropCollection:
The following cases are used to test `drop_collection` function
# NOTE(review): only the final has_collection assertion is visible here; the drop call
# that should precede it does not appear in these lines.
def test_drop_collection(self, client, collection):
target: test delete collection created with correct params
method: create collection and then delete,
assert the value returned by delete method
expected: status ok, and no collection in collections
assert not client.has_collection(collection)
# Dropping a freshly generated, never-created name must return a falsy value.
def test_drop_collection_not_existed(self, client):
target: test if collection not created
method: random a collection name, which not existed in db, assert the exception raised returned by drp_collection method
expected: raise exception
collection_name = gen_unique_str()
assert not client.drop_collection(collection_name)
# Invalid-name parameters; the last expression joins 256 "a" strings with "a",
# which yields a 511-character name.
" ",
"12 s",
" siede ",
"a".join("a" for i in range(256))
def get_invalid_collection_name(self, request):
yield request.param
# Dropping a collection by an invalid name must return a falsy value.
def test_drop_collection_with_invalid_collection(self, client, get_invalid_collection_name):
target: test if collection not created
method: random a collection name, which not existed in db, assert the exception raised returned by drp_collection method
expected: raise exception
assert not client.drop_collection(get_invalid_collection_name)
import logging
import logging
import pytest
import time
import copy
from utils import *
from constants import *
uid = "info_collection"
# Test cases for `client.info_collection`: default and id collections, non-default
# segment_row_limit, missing/invalid names, row count after insert, and index params.
class TestInfoBase:
The following cases are used to test `get_collection_info` function, no data in collection
# Checks count, auto_id, segment_row_limit and the three field names of a default collection.
def test_get_collection_info(self, client, collection):
target: test get collection info with normal collection
method: create collection with default fields and get collection info
expected: no exception raised, and value returned correct
info = client.info_collection(collection)
assert info['count'] == 0
assert info['auto_id'] == True
assert info['segment_row_limit'] == default_segment_row_limit
assert len(info["fields"]) == 3
for field in info['fields']:
if field['type'] == 'INT64':
assert field['name'] == default_int_field_name
if field['type'] == 'FLOAT':
assert field['name'] == default_float_field_name
if field['type'] == 'VECTOR_FLOAT':
assert field['name'] == default_float_vec_field_name
# Creates a second collection with segment_row_limit = 4096 and reads the value back.
# NOTE(review): the insert targets the `collection` fixture, not the `collection_name`
# created in this test - confirm which collection was meant.
def test_get_collection_info_segment_row_limit(self, client, collection):
target: test get collection info with non-default segment row limit
method: create collection with non-default segment row limit and get collection info
expected: no exception raised
segment_row_limit = 4096
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = segment_row_limit
client.create_collection(collection_name, fields)
client.insert(collection, default_entities)
info = client.info_collection(collection_name)
assert info['segment_row_limit'] == segment_row_limit
# Same checks as the default case, but auto_id must be False for the id collection.
def test_get_collection_info_id_collection(self, client, id_collection):
target: test get collection info with id collection
method: create id collection with auto_id=False and get collection info
expected: no exception raised
info = client.info_collection(id_collection)
assert info['count'] == 0
assert info['auto_id'] == False
assert info['segment_row_limit'] == default_segment_row_limit
assert len(info["fields"]) == 3
# Info for a never-created name must be falsy.
def test_get_collection_info_with_collection_not_existed(self, client):
target: test get collection info with not existed collection
method: call get collection info with random collection name which not in db
expected: not ok
collection_name = gen_unique_str(uid)
assert not client.info_collection(collection_name)
# Invalid-name parameters; the last expression yields a 511-character name.
" ",
"12 s",
" siede ",
"a".join("a" for i in range(256))
def get_invalid_collection_name(self, request):
yield request.param
# Info for an invalid name must be falsy.
def test_get_collection_info_collection_name_invalid(self, client, get_invalid_collection_name):
collection_name = get_invalid_collection_name
assert not client.info_collection(collection_name)
# 'count' must go from 0 to default_nb after inserting default_entities.
def test_row_count_after_insert(self, client, collection):
target: test the change of collection row count after insert data
method: insert entities to collection and get collection info
expected: row count increase
info = client.info_collection(collection)
assert info['count'] == 0
assert client.insert(collection, default_entities)
info = client.info_collection(collection)
assert info['count'] == default_nb
# After creating an IVF_FLAT index, the vector field's 'index_params' must equal the request.
# The create_index result (res) is not checked.
def test_get_collection_info_after_index_created(self, client, collection):
target: test index of collection info after index created
method: create index and get collection info
expected: no exception raised
index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
res = client.create_index(collection, default_float_vec_field_name, index)
info = client.info_collection(collection)
for field in info['fields']:
if field['name'] == default_float_vec_field_name:
assert field['index_params'] == index
import logging
import logging
import pytest
import time
import copy
from utils import *
from constants import *
uid = "stats_collection"
# Test cases for `client.stat_collection`: populated and empty collections, plus
# missing and invalid collection names.
class TestStatsBase:
The following cases are used to test `collection_stats` function
# After inserting default_entities, the collection and its single partition
# (default_partition_name) must both report default_nb rows.
def test_get_collection_stats(self, client, collection):
target: get collections stats
method: call get_collection_stats with created collection
expected: status ok
client.insert(collection, default_entities)
stats = client.stat_collection(collection)
assert stats['row_count'] == default_nb
assert len(stats["partitions"]) == 1
assert stats["partitions"][0]["tag"] == default_partition_name
assert stats["partitions"][0]["row_count"] == default_nb
# Stats for a never-created name must be falsy.
def test_get_collection_stats_collection_not_existed(self, client, collection):
target: get collection stats when collection not existed
method: call collection_stats with a random collection_name, which is not in db
expected: status not ok
collection_name = gen_unique_str(uid)
assert not client.stat_collection(collection_name)
# Invalid-name parameters; the last expression yields a 511-character name.
" ",
"12 s",
" siede ",
"a".join("a" for i in range(256))
def get_invalid_collection_name(self, request):
yield request.param
# Stats for an invalid name must be falsy.
def test_get_collection_stats_name_invalid(self, client, collection, get_invalid_collection_name):
target: get collection stats where collection name is invalid
method: call collection_stats with invalid collection_name
expected: status not ok
assert not client.stat_collection(get_invalid_collection_name)
# An empty collection reports 0 rows overall and 0 rows in its single default partition.
def test_get_collection_stats_empty_collection(self, client, collection):
target: get collection stats where no entity in collection
method: call collection_stats with empty collection
expected: segment = []
stats = client.stat_collection(collection)
assert stats["row_count"] == 0
assert len(stats["partitions"]) == 1
assert stats["partitions"][0]["tag"] == default_partition_name
assert stats["partitions"][0]["row_count"] == 0
import logging
import logging
import pytest
import time
import copy
from utils import *
from constants import *
uid = "list_collection"
class TestListCollections:
The following cases are used to test `list_collections` function
def test_list_collections(self, client, collection):
target: test list collections
method: create collection, assert the value returned by list_collections method
expected: True
collections = map(lambda x: x['collection_name'], client.list_collections())
assert collection in collections
def test_list_collections_not_existed(self, client):
target: test if collection not created
method: random a collection name, which not existed in db, assert the value returned by list_collections method
expected: False
collection_name = gen_unique_str(uid)
collections = map(lambda x: x['collection_name'], client.list_collections())
assert collection_name not in collections
def test_list_collections_no_collection(self, client):
target: test list collections when no collection in db
method: delete all collections and list collections
expected: status is ok and len of result is 0
result = client.list_collections()
assert len(result) == 0
def test_list_collections_multi_collections(self, client, collection):
target: test list collections with multi collections
method: create multi collections and list them
expected: len of list results is equal collection nums
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
assert client.create_collection(collection_name, fields)
collections = map(lambda x: x['collection_name'], client.list_collections())
assert collection_name in collections
assert collection in collections
def test_list_collections_offset(self, client, collection):
target: test list collections with offset parameter
method: create multi collections and list them with offset
expected: first collection with offset=1 equal to second collection with offset=0
collection_num = 2
fields = copy.deepcopy(default_fields)
for i in range(collection_num):
collection_name = gen_unique_str(uid)
assert client.create_collection(collection_name, fields)
collections = list(map(lambda x: x['collection_name'], client.list_collections()))
collections_new = list(map(lambda x: x['collection_name'], client.list_collections(offset=1)))
assert collections[1] == collections_new[0]
def test_list_collections_page_size(self, client, collection):
target: test list collections with page_size parameter
method: create multi collections and list them with page_size
expected: collection num equal to page_size
collection_num = 6
page_size = 5
fields = copy.deepcopy(default_fields)
for i in range(collection_num):
collection_name = gen_unique_str(uid)
assert client.create_collection(collection_name, fields)
collections = list(map(lambda x: x['collection_name'], client.list_collections(page_size=page_size)))
c = list(map(lambda x: x['collection_name'], client.list_collections()))
assert len(collections) == page_size
import pdb
import pdb
import logging
import socket
import pytest
import requests
from utils import gen_unique_str
from client import MilvusClient
from utils import *
# Module-level defaults.
# NOTE(review): no reader of these three names is visible in this section —
# confirm where (and in which units) they are consumed.
timeout = 60
dimension = 128
delete_timeout = 60
def pytest_addoption(parser):
    """
    Register the command-line options used by this test suite.

    Options: --ip (default "localhost"), --service (default ""), --port
    (default 19121), --tag (default "all") and the boolean flag --dry-run.
    """
    option_table = (
        ("--ip", {"action": "store", "default": "localhost"}),
        # ("--ip", {"action": "store", "default": ""}),
        ("--service", {"action": "store", "default": ""}),
        ("--port", {"action": "store", "default": 19121}),
        ("--tag", {"action": "store", "default": "all", "help": "only run tests matching the tag."}),
        ('--dry-run', {"action": 'store_true', "default": False}),
    )
    for flag, settings in option_table:
        parser.addoption(flag, **settings)
def pytest_configure(config):
    """
    Register the custom `tag` marker so pytest knows about it.

    Without the `addinivalue_line` call the two strings are only a discarded
    expression and the marker is never registered.
    """
    # register an additional marker
    config.addinivalue_line(
        "markers", "tag(name): mark test to run only matching the tag"
    )
def pytest_runtest_setup(item):
    """
    Skip a test whose `tag` markers do not include the tag given with --tag.

    Tests without any `tag` marker always run, and so does every test when
    --tag is left at "all".
    """
    # Collect every argument of every `tag` marker on the item; the original
    # inner loop had no body, so the list could never be filled.
    tags = [tag for marker in item.iter_markers(name="tag") for tag in marker.args]
    if not tags:
        return
    cmd_tag = item.config.getoption("--tag")
    if cmd_tag != "all" and cmd_tag not in tags:
        pytest.skip("test requires tag in {!r}".format(tags))
# Test-loop hook that reads the --dry-run option and walks session.items when it is set.
# NOTE(review): the body of the `for` loop is not visible here, so what happens
# per item cannot be stated — confirm against the full file.
def pytest_runtestloop(session):
if session.config.getoption('--dry-run'):
for item in session.items:
return True
def check_server_connection(request):
    """
    Report whether the configured server address can be resolved.

    Reads --ip and --port from the pytest config. An empty ip or 'localhost' is
    treated as reachable without any lookup; any other host is resolved with
    socket.getaddrinfo.

    Returns:
        True when no lookup is needed or the lookup succeeds, False when the
        lookup raises.
    """
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    if not ip or ip == 'localhost':
        return True
    try:
        socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
    except Exception as e:
        # Best-effort probe: report the failure and let the caller decide.
        print("Socket connect failed: %s" % str(e))
        return False
    return True
def args(request):
    """
    Build the connection settings shared by the tests.

    Reads --ip, --service and --port from the pytest config, derives the base
    URL "http://<ip>:<port>/" and creates a MilvusClient for it.

    Returns:
        dict with the keys "ip", "port", "service_name", "url" and "client".
    """
    # NOTE(review): no fixture decorator is visible on this function — confirm against the full file.
    option = request.config.getoption
    ip = option("--ip")
    service_name = option("--service")
    port = option("--port")
    url = "http://%s:%s/" % (ip, port)
    return {
        "ip": ip,
        "port": port,
        "service_name": service_name,
        "url": url,
        "client": MilvusClient(url),
    }
def client(request, args):
    """Return the client object stored under the "client" key of `args`."""
    return args["client"]
# Creates a uniquely named collection with the schema from gen_default_fields(),
# asserts that it exists and returns its name. The name prefix is the requesting
# module's `collection_id` attribute, or "test" when the module has none.
# NOTE(review): the `except` clause and the nested `teardown` have no visible
# `try:`, handler body, teardown body or finalizer registration in this view —
# confirm the error handling and cleanup against the full file.
def collection(request, client):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
create_params = gen_default_fields()
if not client.create_collection(collection_name, create_params):
except Exception as e:
def teardown():
assert client.has_collection(collection_name)
return collection_name
# customised id
# Same flow as above but the schema comes from gen_default_fields(auto_id=False);
# asserts the collection exists and returns its name.
# NOTE(review): the `try:`, the handler body and the teardown body/registration
# are not visible in this view — confirm against the full file.
def id_collection(request, client):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
client.create_collection(collection_name, gen_default_fields(auto_id=False))
except Exception as e:
def teardown():
assert client.has_collection(collection_name)
return collection_name
# Creates a uniquely named collection with the schema from
# gen_default_fields(binary=True), asserts it exists and returns its name.
# NOTE(review): the `try:`, the handler body and the teardown body/registration
# are not visible in this view — confirm against the full file.
def binary_collection(request, client):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
client.create_collection(collection_name, gen_default_fields(binary=True))
except Exception as e:
def teardown():
assert client.has_collection(collection_name)
return collection_name
# customised id
# Creates a uniquely named collection with the schema from
# gen_default_fields(auto_id=False, binary=True), asserts it exists and returns its name.
# NOTE(review): the `try:`, the handler body and the teardown body/registration
# are not visible in this view — confirm against the full file.
def binary_id_collection(request, client):
ori_collection_name = getattr(request.module, "collection_id", "test")
collection_name = gen_unique_str(ori_collection_name)
client.create_collection(collection_name, gen_default_fields(auto_id=False, binary=True))
except Exception as e:
def teardown():
assert client.has_collection(collection_name)
return collection_name
import utils
import utils
# Shared test data, built once when this module is imported.
# Field schemas from utils.gen_default_fields(), float and binary variants.
default_fields = utils.gen_default_fields()
default_binary_fields = utils.gen_default_fields(binary=True)
# Single-entity payloads (generator called with 1).
default_entity = utils.gen_entities(1)
# gen_binary_entities returns a pair: the raw vectors first, the entities second.
default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)
# Bulk payloads sized by utils.default_nb.
default_entities = utils.gen_entities(utils.default_nb)
default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)
# Abort the script as soon as any command fails.
set -e
set -e
# When the first argument is 'start', block forever by following /dev/null.
if [ "$1" = 'start' ]; then
tail -f /dev/null
# NOTE(review): no closing `fi` is visible in this view — confirm against the full file.
# Replace this shell with the command given as arguments, each argument kept intact.
exec "$@"
import logging
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
uid = "test_delete"
class TestDeleteBase:
The following cases are used to test `insert` function
def get_simple_index(self, request, client):
if str(client.system_cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
def get_filter_field(self, request):
yield request.param
def get_vector_field(self, request):
yield request.param
def test_delete_entity_id_not_exised(self, client, collection):
target: test get entity, params entity_id not existed
method: get entity
expected: result empty
assert client.insert(collection, default_entity)
res_flush = client.flush([collection])
entities = client.get_entities(collection, 1)
assert entities
def test_delete_empty_collection(self, client, collection):
target: test delete entity, params collection_name not existed
method: add entity and delete
expected: status DELETED
status = client.delete(collection, ["0"])
assert status
def test_delete_entity_collection_not_existed(self, client, collection):
target: test delete entity, params collection_name not existed
method: add entity and delete
expected: error raised
collection_new = gen_unique_str()
status = client.delete(collection_new, ["0"])
assert not status
def test_insert_delete(self, client, collection):
target: test delete entity
method: add entities and delete
expected: no error raised
ids = client.insert(collection, default_entities)
delete_ids = [ids[0]]
status = client.delete(collection, delete_ids)
assert status
res_count = client.count_collection(collection)
assert res_count == default_nb - 1
import logging
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
uid = "test_get"
class TestGetBase:
The following cases are used to test `insert` function
def get_simple_index(self, request, client):
if str(client.system_cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
def get_filter_field(self, request):
yield request.param
def get_vector_field(self, request):
yield request.param
def test_get_entity_id_not_exised(self, client, collection):
target: test delete entity, params entity_id not existed
method: add entity and delete
expected: entities empty
ids = client.insert(collection, default_entity)
entities = client.get_entities(collection, [0,1])
assert entities
def test_get_empty_collection(self, client, collection):
target: test hry entity, params collection_name not existed
method: add entity and get
expected: entities empty
entities = client.get_entities(collection, [0])
assert entities
def test_get_entity_collection_not_existed(self, client, collection):
target: test get entity, params collection_name not existed
method: add entity and get
expected: code error
collection_new = gen_unique_str()
entities = client.get_entities(collection_new, [0])
assert not entities
def test_insert_get(self, client, collection):
target: test get entity
method: add entities and get
expected: entity returned
ids = client.insert(collection, default_entities)
delete_ids = [ids[0]]
entities = client.get_entities(collection, delete_ids)
assert len(entities) == 1
def test_insert_get_batch(self, client, collection):
target: test get entity
method: add entities and get
expected: entity returned
get_length = 10
ids = client.insert(collection, default_entities)
delete_ids = ids[:get_length]
entities = client.get_entities(collection, delete_ids)
assert len(entities) == get_length
import logging
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
uid = "test_insert"
class TestInsertBase:
The following cases are used to test `insert` function
def get_simple_index(self, request, client):
if str(client.system_cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
def get_filter_field(self, request):
yield request.param
def get_vector_field(self, request):
yield request.param
def test_insert_with_empty_entity(self, client, collection):
target: test add vectors with empty vectors list
method: set empty vectors list as add method params
expected: raises a Exception
assert not client.insert(collection, [])
def test_insert_with_entity(self, client, collection):
target: test add vectors with an entity
method: insert with an entity
expected: count correct
assert client.insert(collection, default_entity)
res_flush = client.flush([collection])
count = client.count_collection(collection)
assert count == 1
def test_insert_with_entities(self, client, collection):
target: test add vectors with entities
method: insert entities
expected: count correct
assert client.insert(collection, default_entities)
res_flush = client.flush([collection])
count = client.count_collection(collection)
assert count == default_nb
def test_insert_with_field_not_match(self, client, collection):
target: insert field not match with collection schema
method: pop a field(int64)
expected: insert failed
entity = copy.deepcopy(default_entity)
assert not client.insert(collection, entity)
def test_insert_with_tag_not_existed(self, client, collection):
target: test add vectors with an entity
method: insert an entity with tag, which not created
expected: insert failed
assert not client.insert(collection, default_entity, tag=default_tag)
def test_insert_with_tag_not_existed(self, client, collection):
target: test add vectors with an entity
method: insert an entity with tag
expected: count correct
client.create_partition(collection, default_tag)
assert client.insert(collection, default_entity, tag=default_tag)
res_flush = client.flush([collection])
count = client.count_collection(collection)
assert count == 1
class TestInsertID:
The following cases are used to test `insert` function
def get_simple_index(self, request, client):
if str(client.system_cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
def get_filter_field(self, request):
yield request.param
def get_vector_field(self, request):
yield request.param
def test_insert_with_entity_id_not_matched(self, client, id_collection):
target: test add vectors with an entity
method: insert with an entity
expected: insert failed
assert not client.insert(id_collection, default_entity)
def test_insert_with_entity(self, client, id_collection):
target: test add vectors with an entity, in a id_collection
method: insert with an entity, with customized ids
expected: insert success
entity = copy.deepcopy(default_entity)
entity[0].update({"__id": 1})
assert client.insert(id_collection, entity)
res_flush = client.flush([id_collection])
count = client.count_collection(id_collection)
assert count == 1
import logging
import logging
import pytest
import requests
from utils import *
from constants import *
# Prefix passed to gen_unique_str() when a search test needs a fresh collection name.
uid = "test_search"
# Upper bound used when asserting that the nearest hit's distance is close to zero.
epsilon = 0.001
# Vector field the queries in this module target.
field_name = default_float_vec_field_name
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
# Query built once at import time from the default entities; gen_query_vectors
# returns the query first and the vectors second.
default_query, default_query_vectors = gen_query_vectors(field_name, default_entities, default_top_k, default_nq)
def init_data(client, collection, nb=default_nb, partition_tags=None, auto_id=True):
    """
    Insert `nb` entities into `collection` exactly once and return them with their ids.

    Args:
        client: object providing `insert(collection, entities, **kwargs)`.
        collection: name of the target collection.
        nb: number of entities; the shared default_entities are reused when it
            equals default_nb, otherwise gen_entities(nb) builds new ones.
        partition_tags: when not None, forwarded to insert as `partition_tag`.
        auto_id: when False, the ids 0..nb-1 are forwarded to insert as `ids`.

    Returns:
        (insert_entities, ids) where ids is whatever client.insert returned.
    """
    insert_entities = default_entities if nb == default_nb else gen_entities(nb)
    # Build the optional arguments first so exactly one insert is issued; the
    # original flat sequence of calls inserted the data more than once.
    insert_kwargs = {}
    if not auto_id:
        insert_kwargs["ids"] = [i for i in range(nb)]
    if partition_tags is not None:
        insert_kwargs["partition_tag"] = partition_tags
    ids = client.insert(collection, insert_entities, **insert_kwargs)
    return insert_entities, ids
def init_binary_data(client, collection, nb=default_nb, partition_tags=None, auto_id=True):
    """
    Generate binary entities and insert to collection

    Args:
        client: object providing `insert(collection, entities, **kwargs)`.
        collection: name of the target collection.
        nb: number of entities; the shared default binary data is reused when it
            equals default_nb, otherwise gen_binary_entities(nb) builds new data.
        partition_tags: when not None, forwarded to insert as `partition_tag`.
        auto_id: when False, the ids 0..nb-1 are forwarded to insert as `ids`.

    Returns:
        (insert_raw_vectors, insert_entities, ids) where ids is whatever
        client.insert returned.
    """
    if nb == default_nb:
        insert_entities = default_binary_entities
        insert_raw_vectors = default_raw_binary_vectors
    else:
        # Only generate fresh data for a non-default size; the original flat
        # sequence always overwrote the defaults chosen above.
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
    # Build the optional arguments first so exactly one insert is issued.
    insert_kwargs = {}
    if not auto_id:
        insert_kwargs["ids"] = [i for i in range(nb)]
    if partition_tags is not None:
        insert_kwargs["partition_tag"] = partition_tags
    ids = client.insert(collection, insert_entities, **insert_kwargs)
    return insert_raw_vectors, insert_entities, ids
def check_id_result(results, ids):
    """
    Tell whether every id in `ids` appears somewhere in `results`.

    Args:
        results: iterable of result lists; each item is a mapping with an 'id' entry.
        ids: ids expected to be present; compared after conversion with int().

    Returns:
        True when all ids are found (also for an empty `ids`), False otherwise.
    """
    found = {int(item['id']) for result in results for item in result}
    return all(int(wanted) in found for wanted in ids)
class TestSearchBase:
    """
    The following cases are used to test `search` function
    """

    # NOTE(review): only the `params=[1, 10]` argument of the two decorators below
    # is visible; the default function scope is assumed — confirm against the full file.
    @pytest.fixture(params=[1, 10])
    def get_top_k(self, request):
        """
        generate top-k params
        """
        yield request.param

    @pytest.fixture(params=[1, 10])
    def get_nq(self, request):
        """Yield the nq value of the current parametrized run."""
        yield request.param

    # Uses the `client` fixture and `system_cmd`, like every other test class of
    # this suite; the previous `connect._cmd` form referenced a fixture that the
    # visible conftest does not define.
    # NOTE(review): reads `request.param`, so this is presumably a parametrized
    # fixture; no decorator is visible here — confirm against the full file.
    def get_simple_index(self, request, client):
        """Return request.param, skipping first when the server reports CPU mode and the index type is not CPU-capable."""
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    def test_search_flat(self, client, collection, get_nq, get_top_k):
        """
        target: test basic search function, all the search params is correct, change top-k value
        method: search with the given vectors, check the result
        expected: the length of result is top_k
        """
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq)
        data = client.search(collection, query)
        res = data['result']
        assert data['num'] == nq
        assert len(res) == nq
        assert len(res[0]) == top_k
        assert float(res[0][0]['distance']) <= epsilon
        assert check_id_result(res, ids[:nq])

    def test_search_invalid_top_k(self, client, collection):
        """
        target: test search with invalid top_k that large than max_top_k
        method: call search with invalid top_k
        expected: search returns a falsy result
        """
        top_k = max_top_k + 1
        nq = 1
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq)
        assert not client.search(collection, query)

    def test_search_fields(self, client, collection, ):
        """
        target: test search with field
        method: call search with field and check return whether contain field value
        expected: return field value
        """
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
        data = client.search(collection, query, fields=[default_int_field_name])
        res = data['result']
        assert data['num'] == default_nq
        assert len(res) == default_nq
        assert len(res[0]) == default_top_k
        assert default_int_field_name in res[0][0]['entity'].keys()

    def test_search_invalid_n_probe(self, client, collection, ):
        """
        target: test basic search function with invalid n_probe
        method: call search function
        expected: not ok
        """
        entities, ids = init_data(client, collection)
        assert client.create_index(collection, default_float_vec_field_name, default_index)
        query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq,
                                                 search_params={"nprobe": 0})
        assert not client.search(collection, query)

    def test_search_not_existed_collection(self, client, collection):
        """
        target: test basic search with not existed collection
        method: call search function
        expected: not ok
        """
        collection_name = gen_unique_str(uid)
        assert not client.search(collection_name, default_query)

    def test_search_empty_collection(self, client, collection):
        """
        target: test basic search function with empty collection
        method: call search function
        expected: return 0 entities
        """
        assert 0 == client.count_collection(collection)
        data = client.search(collection, default_query)
        res = data['result']
        assert data['num'] == 0
        assert len(res) == 0

    # NOTE(review): only these four parameter values are visible; the full file
    # may list more — confirm before relying on this list.
    @pytest.fixture(
        params=[
            " ",
            "12 s",
            " siede ",
            "a".join("a" for i in range(256))
        ]
    )
    def get_invalid_collection_name(self, request):
        """Yield the invalid collection name of the current parametrized run."""
        yield request.param

    # Renamed from `test_get_collection_stats_name_invalid`, a name copied from
    # the stats suite that did not describe this search case.
    def test_search_collection_name_invalid(self, client, collection, get_invalid_collection_name):
        """
        target: test search when collection name is invalid
        method: call search with invalid collection_name
        expected: status not ok
        """
        collection_name = get_invalid_collection_name
        assert not client.search(collection_name, default_query)

    def test_search_invalid_format_query(self, client, collection):
        """
        target: test search with invalid format query
        method: call search with invalid query string
        expected: status not ok and url `/collections/xxx/entities` return correct
        """
        entities, ids = init_data(client, collection)
        must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[]], "params": {"nprobe": 10}}}}
        must_param["vector"][field_name]["metric_type"] = 'L2'
        query = {
            "bool": {
                "must": [must_param]
            }
        }
        assert not client.search(collection, query)

    def test_search_with_invalid_metric_type(self, client, collection):
        """
        target: test search with a metric type the server should reject
        method: build the query with metric_type="l2" (lower case) and search
        expected: search returns a falsy result
        """
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l2")
        assert not client.search(collection, query)

    def test_search_binary_flat(self, client, binary_collection):
        """
        target: test basic search on a binary collection
        method: insert the default binary entities and search with the default top-k and nq
        expected: one result list per query, top-k hits each, nearest hit at distance ~0
        """
        # init_binary_data returns three values; init_data returns two and inserts
        # float entities, so it could never satisfy this unpacking.
        raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
        # NOTE(review): `field_name` is the float vector field; a binary collection
        # presumably needs default_binary_vec_field_name and a binary metric — confirm.
        query, query_vectors = gen_query_vectors(field_name, binary_entities, default_top_k, default_nq)
        data = client.search(binary_collection, query)
        res = data['result']
        assert data['num'] == default_nq
        assert len(res) == default_nq
        assert len(res[0]) == default_top_k
        assert float(res[0][0]['distance']) <= epsilon
        assert check_id_result(res, ids[:default_nq])
import pdb
import logging
import logging
import time
import random
from locust import User, events
from client import MilvusClient
url = ''


class MilvusTask(object):
    """
    Proxy around MilvusClient that reports every call to locust.

    Attribute access is forwarded to the wrapped client; the returned callable
    measures the call's duration in milliseconds and fires `request_success`
    for a truthy result, `request_failure` for a falsy result or an exception.
    """

    def __init__(self, url):
        self.request_type = "http"
        self.m = MilvusClient(url)
        # logging.getLogger().info(id(self.m))

    def __getattr__(self, name):
        func = getattr(self.m, name)

        def wrapper(*args, **kwargs):
            """Call the client method, report its outcome and return its result."""
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                # Best-effort load test: report the error and keep the user running.
                total_time = int((time.time() - start_time) * 1000)
                events.request_failure.fire(request_type=self.request_type, name=name, response_time=total_time,
                                            exception=e, response_length=0)
                return None
            total_time = int((time.time() - start_time) * 1000)
            if result:
                events.request_success.fire(request_type=self.request_type, name=name, response_time=total_time,
                                            response_length=0)
            else:
                # No exception exists on this path; the original referenced an
                # undefined `e` here, so build an explicit one for the report.
                failure = Exception("%s returned a falsy result" % name)
                events.request_failure.fire(request_type=self.request_type, name=name, response_time=total_time,
                                            exception=failure, response_length=0)
            # Hand the result back so callers can use what the client returned.
            return result

        return wrapper
@ -0,0 +1,14 @@
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_date_format = %Y-%m-%d %H:%M:%S
log_cli = true
log_level = 20
timeout = 360
markers =
level: test level
#level = 1
@ -0,0 +1,12 @@
@ -0,0 +1,25 @@
@ -0,0 +1,11 @@
@ -0,0 +1,4 @@
# Forward the script's arguments to pytest unchanged. "$@" must be quoted:
# unquoted, arguments containing spaces or glob characters are re-split.
# NOTE(review): the command appears twice in this view — confirm whether the
# repetition is intentional.
pytest . "$@"
pytest . "$@"
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
uid = "test_index"
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
class TestIndexBase:
The following cases are used to test `insert` function
def get_simple_index(self, request, client):
if str(client.system_cmd("mode")) == "CPU":
if request.param["index_type"] in index_cpu_not_support():
pytest.skip("CPU not support index_type: ivf_sq8h")
return request.param
def get_filter_field(self, request):
yield request.param
def get_vector_field(self, request):
yield request.param
def test_create_index(self, client, collection):
target: test create index interface
method: create collection and add entities in it, create index
expected: return search success
ids = client.insert(collection, default_entities)
res = client.create_index(collection, default_float_vec_field_name, default_index)
# get index info
res_info_index = client.describe_index(collection, default_float_vec_field_name)
assert res_info_index == default_index
def test_create_index_on_field_not_existed(self, client, collection):
target: test create index interface
method: create collection and add entities in it, create index on field not existed
expected: create failed
tmp_field_name = gen_unique_str()
ids = client.insert(collection, default_entities)
assert not client.create_index(collection, tmp_field_name, default_index)
def test_create_index_on_field(self, client, collection):
target: test create index interface
method: create collection and add entities in it, create index on other field
expected: create failed
tmp_field_name = "int64"
ids = client.insert(collection, default_entities)
assert not client.create_index(collection, tmp_field_name, default_index)
def test_create_index_collection_not_existed(self, client):
target: test create index interface when collection name not existed
method: create collection and add entities in it, create index
, make sure the collection name not in index
expected: create index failed
collection_name = gen_unique_str(uid)
assert not client.create_index(collection_name, default_float_vec_field_name, default_index)
def test_drop_index(self, client, collection):
target: test drop index interface
method: create collection and add entities in it, create index, call drop index
expected: return code 0, and default index param
# ids = connect.insert(collection, entities)
client.create_index(collection, default_float_vec_field_name, default_index)
client.drop_index(collection, default_float_vec_field_name)
res_info_index = client.describe_index(collection, default_float_vec_field_name)
assert not res_info_index
def test_drop_index_collection_not_existed(self, client):
target: test drop index interface when collection name not existed
method: create collection and add entities in it, create index
, make sure the collection name not in index, and then drop it
expected: return code not equals to 0, drop index failed
collection_name = gen_unique_str(uid)
assert not client.drop_index(collection_name, default_float_vec_field_name)
import time
import time
import random
import pdb
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
class TestCreateBase:
The following cases are used to test `create_partition` function
# Only issues the call; nothing is asserted on its return value.
def test_create_partition(self, client, collection):
target: test create partition, check status returned
method: call function: create_partition
expected: status ok
client.create_partition(collection, default_tag)
# Creating the same tag a second time must return a falsy result.
def test_create_partition_repeat(self, client, collection):
target: test create partition, check status returned
method: call function: create_partition
expected: status ok
client.create_partition(collection, default_tag)
ret = client.create_partition(collection, default_tag)
assert(not ret)
def test_create_partition_collection_not_existed(self, client):
target: test create partition, its owner collection name not existed in db, check status returned
method: call function: create_partition
expected: status not ok
collection_name = gen_unique_str()
ret = client.create_partition(collection_name, default_tag)
assert(not ret)
# A None tag must be rejected: the call is expected to return a falsy result.
def test_create_partition_tag_name_None(self, client, collection):
target: test create partition, tag name set None, check status returned
method: call function: create_partition
expected: status ok
tag_name = None
ret = client.create_partition(collection, tag_name)
assert(not ret)
def test_create_different_partition_tags(self, client, collection):
target: test create partition twice with different names
method: call function: create_partition, and again
expected: status ok
client.create_partition(collection, default_tag)
tag_name = gen_unique_str()
client.create_partition(collection, tag_name)
ret = client.list_partitions(collection)
tag_list = []
# NOTE(review): the body of this loop is not visible here, so nothing visibly
# fills tag_list before the assertions below — confirm against the full file.
for item in ret:
assert default_tag in tag_list
assert tag_name in tag_list
assert "_default" in tag_list
class TestShowBase:
The following cases are used to test `list_partitions` function
def test_list_partitions(self, client, collection):
target: test show partitions, check status and partitions returned
method: create partition first, then call function: list_partitions
expected: status ok, partition correct
client.create_partition(collection, default_tag)
ret = client.list_partitions(collection)
tag_list = []
# NOTE(review): the body of this loop is not visible here, so nothing visibly
# fills tag_list before the assertion below — confirm against the full file.
for item in ret:
assert default_tag in tag_list
# A collection with no explicitly created partition is expected to list exactly one entry.
def test_list_partitions_no_partition(self, client, collection):
target: test show partitions with collection name, check status and partitions returned
method: call function: list_partitions
expected: status ok, partitions correct
res = client.list_partitions(collection)
assert len(res) == 1
class TestDropBase:
The following cases are used to test `drop_partition` function
def test_drop_partition(self, client, collection):
target: test drop partition, check status and partition if existed
method: create partitions first, then call function: drop_partition
expected: status ok, no partitions in db
client.create_partition(collection, default_tag)
client.drop_partition(collection, default_tag)
res = client.list_partitions(collection)
tag_list = []
# NOTE(review): the body of this loop is not visible here; with tag_list left
# empty the assertion below holds trivially — confirm against the full file.
for item in res:
assert default_tag not in tag_list
# Dropping a tag that was never created must return a falsy result.
def test_drop_partition_tag_not_existed(self, client, collection):
target: test drop partition, but tag not existed
method: create partitions first, then call function: drop_partition
expected: status not ok
client.create_partition(collection, default_tag)
new_tag = "new_tag"
ret = client.drop_partition(collection, new_tag)
assert(not ret)
# Despite its name, this case drops from a generated collection name that was
# never created, while the tag itself exists on the fixture collection.
def test_drop_partition_tag_not_existed_A(self, client, collection):
target: test drop partition, but collection not existed
method: create partitions first, then call function: drop_partition
expected: status not ok
client.create_partition(collection, default_tag)
new_collection = gen_unique_str()
ret = client.drop_partition(new_collection, default_tag)
assert(not ret)
class TestNameInvalid(object):
# NOTE(review): this expression looks like the one visible entry of a fixture
# `params=[...]` list; the decorator and any other entries are not visible here —
# confirm against the full file. It builds a 511-character name.
"a".join("a" for i in range(256))
# Yields `request.param`, i.e. the tag name of the current parametrized run.
def get_tag_name(self, request):
yield request.param
# NOTE(review): same as above — presumably the visible entry of this fixture's params list.
"a".join("a" for i in range(256))
# Yields `request.param`, i.e. the collection name of the current parametrized run.
def get_collection_name(self, request):
yield request.param
# The partition is created on the valid fixture collection; the drop then uses the invalid name.
def test_drop_partition_with_invalid_collection_name(self, client, collection, get_collection_name):
target: test drop partition, with invalid collection name, check status returned
method: call function: drop_partition
expected: status not ok
collection_name = get_collection_name
client.create_partition(collection, default_tag)
ret = client.drop_partition(collection_name, default_tag)
assert(not ret)
def test_drop_partition_with_invalid_tag_name(self, client, collection, get_tag_name):
target: test drop partition, with invalid tag name, check status returned
method: call function: drop_partition
expected: status not ok
tag_name = get_tag_name
client.create_partition(collection, default_tag)
ret = client.drop_partition(collection, tag_name)
assert(not ret)
def test_list_partitions_with_invalid_collection_name(self, client, collection, get_collection_name):
target: test show partitions, with invalid collection name, check status returned
method: call function: list_partitions
expected: status not ok
collection_name = get_collection_name
client.create_partition(collection, default_tag)
ret = client.list_partitions(collection_name)
assert(not ret)
import logging
import logging
import pytest
import pdb
from utils import *
__version__ = '0.11.0'
class TestPing:
    """Smoke tests for the server's `system_cmd` interface.

    Every test relies on a `client` fixture that exposes `system_cmd(cmd)`;
    the fixture itself is defined outside this module.
    """

    def test_server_version(self, client):
        '''
        target: test get the server version
        method: call the server_version method after connected
        expected: version should be the milvus version
        '''
        res = client.system_cmd("version")
        assert res == __version__

    def test_server_status(self, client):
        '''
        target: test get the server status
        method: call the server_status method after connected
        expected: status returned should be ok
        '''
        msg = client.system_cmd("status")
        assert msg

    def test_server_cmd_with_params_version(self, client):
        '''
        target: test cmd: version
        method: cmd = "version" ...
        expected: when cmd = 'version', return version of server;
        '''
        cmd = "version"
        msg = client.system_cmd(cmd)
        assert msg == __version__

    def test_server_cmd_with_params_others(self, client):
        '''
        target: test an unsupported cmd
        method: cmd = "rm -rf test" ...
        expected: the server replies with the unknown-command message
        '''
        # The command is sent as an opaque string; the test only checks the reply.
        cmd = "rm -rf test"
        msg = client.system_cmd(cmd)
        assert msg == default_unknown_cmd
import os
import os
import sys
import random
import pdb
import string
import struct
import logging
import time, datetime
import copy
import numpy as np
from sklearn import preprocessing
from milvus import Milvus, DataType
# Server port used in the README examples (--port=19530); presumably the Milvus default.
port = 19530
# Absolute per-component tolerance used by assert_equal_vector.
epsilon = 0.000001
# Kubernetes namespace passed to the pod calls in restart_server.
namespace = "milvus"
# auto_flush_interval value restored by enable_flush.
default_flush_interval = 1
# auto_flush_interval value applied by disable_flush.
big_flush_interval = 1000
default_drop_interval = 3
# Default vector dimension used by the field and entity generators.
default_dim = 128
# Default row count; the default term/range expressions are sized from it.
default_nb = 1200
default_top_k = 10
default_nq = 1
max_top_k = 16384
max_partition_num = 256
# Value placed under "segment_row_limit" by gen_default_fields.
default_segment_row_limit = 1000
default_server_segment_row_limit = 1024 * 512
# Field names shared by the schema, entity and query helpers in this module.
default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"
default_int_field_name = "int64"
default_float_field_name = "float"
default_double_field_name = "double"
default_partition_name = "_default"
default_tag = "1970_01_01"
default_other_fields = ["INT64", "FLOAT"]
# Reply that TestPing expects from system_cmd for an unsupported command.
default_unknown_cmd = "Unknown command"
# TODO: disable RHNSW_SQ/PQ in 0.11.0
# NOTE: all_index_types and default_index_params are parallel lists --
# default_index_params[i] holds the build parameters for all_index_types[i]
# (gen_simple_index / gen_binary_index index both with the same `i`).
# Keep commented-out entries aligned when enabling or disabling a type.
all_index_types = [
    "FLAT",
    "IVF_FLAT",
    "IVF_SQ8",
    "IVF_SQ8_HYBRID",
    "IVF_PQ",
    "HNSW",
    # "NSG",
    "ANNOY",
    # "RHNSW_PQ",
    # "RHNSW_SQ",
    "BIN_FLAT",
    "BIN_IVF_FLAT"
]

default_index_params = [
    {"nlist": 128},
    {"nlist": 128},
    {"nlist": 128},
    {"nlist": 128},
    {"nlist": 128, "m": 16},
    {"M": 48, "efConstruction": 500},
    # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
    {"n_trees": 50},
    # {"M": 48, "efConstruction": 500, "PQM": 16},
    # {"M": 48, "efConstruction": 500},
    {"nlist": 128},
    {"nlist": 128}
]
def index_cpu_not_support():
    """Return the index type names this suite lists as not supported on CPU.

    A fresh list is built on every call, so callers may mutate the result.
    """
    unsupported = ["IVF_SQ8_HYBRID"]
    return unsupported
def binary_support():
    """Return the index type names used for binary vectors.

    A fresh list is built on every call, so callers may mutate the result.
    """
    binary_index_types = ["BIN_FLAT", "BIN_IVF_FLAT"]
    return binary_index_types
def delete_support():
    """Return the index type names this suite lists as supporting deletion.

    The content is currently identical to what ivf() returns.
    """
    supported = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
    return supported
def ivf():
    """Return the index type names searched with an "nprobe" parameter.

    get_search_param() consults this list; note that "FLAT" is included.
    """
    nprobe_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
    return nprobe_index_types
def binary_metrics():
    """Return the metric type names that apply to binary vectors.

    One name per binary distance helper defined in this module
    (jaccard, hamming, tanimoto, substructure, superstructure).
    """
    return ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
def structure_metrics():
    """Return the structure-based binary metric type names.

    These correspond to the substructure() and superstructure() helpers.
    """
    return ["SUBSTRUCTURE", "SUPERSTRUCTURE"]
def l2(x, y):
    """Return the Euclidean norm of the element-wise difference x - y."""
    difference = np.array(x) - np.array(y)
    return np.linalg.norm(difference)
def ip(x, y):
    """Return the inner product of x and y, computed with numpy.inner."""
    left = np.array(x)
    right = np.array(y)
    return np.inner(left, right)
def jaccard(x, y):
    """Return the Jaccard distance between two bit sequences.

    Computed as 1 - |x AND y| / |x OR y| after casting both inputs to bool.
    """
    # The `np.bool` alias was removed in NumPy 1.24; the builtin is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())
def hamming(x, y):
    """Return the Hamming distance: the number of positions where x and y differ.

    Both inputs are cast to bool before the comparison.
    """
    # The `np.bool` alias was removed in NumPy 1.24; the builtin is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return np.bitwise_xor(x, y).sum()
def tanimoto(x, y):
    """Return the Tanimoto distance between two bit sequences.

    Computed as -log2(|x AND y| / |x OR y|) after casting both inputs to bool.
    """
    # The `np.bool` alias was removed in NumPy 1.24; the builtin is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return -np.log2(np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum()))
def substructure(x, y):
    """Return 1 - |x AND y| / (number of set bits in y).

    Both inputs are cast to bool first; the result is 0 when every bit set
    in y is also set in x.
    """
    # The `np.bool` alias was removed in NumPy 1.24; the builtin is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(y)
def superstructure(x, y):
    """Return 1 - |x AND y| / (number of set bits in x).

    Both inputs are cast to bool first; the result is 0 when every bit set
    in x is also set in y.
    """
    # The `np.bool` alias was removed in NumPy 1.24; the builtin is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x)
def get_default_field_name(data_type):
    """Map a data type name to the default field name used by this suite.

    Args:
        data_type: one of "VECTOR_FLOAT", "VECTOR_BINARY", "INT64", "FLOAT"
            or "DOUBLE".

    Returns:
        The matching module-level default field name.

    Raises:
        ValueError: if `data_type` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    if data_type == "VECTOR_FLOAT":
        field_name = default_float_vec_field_name
    elif data_type == "VECTOR_BINARY":
        field_name = default_binary_vec_field_name
    elif data_type == "INT64":
        field_name = default_int_field_name
    elif data_type == "FLOAT":
        field_name = default_float_field_name
    elif data_type == "DOUBLE":
        field_name = default_double_field_name
    else:
        raise ValueError("Unsupported data type: %r" % (data_type,))
    return field_name
def reset_build_index_threshold(connect):
    """Set the server's engine.build_index_threshold config to 1024.

    Args:
        connect: client object exposing set_config(parent_key, child_key, value).
    """
    threshold = 1024
    connect.set_config("engine", "build_index_threshold", threshold)
def disable_flush(connect):
    """Raise storage.auto_flush_interval to `big_flush_interval`.

    Args:
        connect: client object exposing set_config(parent_key, child_key, value).
    """
    interval = big_flush_interval
    connect.set_config("storage", "auto_flush_interval", interval)
def enable_flush(connect):
    """Restore storage.auto_flush_interval to `default_flush_interval`.

    The value is read back afterwards and asserted to equal the string form
    of the default, so a failed update surfaces as an AssertionError.

    Args:
        connect: client object exposing set_config and get_config.
    """
    section, key = "storage", "auto_flush_interval"
    # reset auto_flush_interval=1
    connect.set_config(section, key, default_flush_interval)
    current = connect.get_config(section, key)
    expected = str(default_flush_interval)
    assert current == expected
def gen_inaccuracy(num):
    """Return `num` divided by 255.0."""
    divisor = 255.0
    return num / divisor
def gen_vectors(num, dim, is_normal=True):
    """Generate `num` random float vectors of length `dim`, L2-normalised.

    Args:
        num: number of vectors to generate.
        dim: number of components per vector.
        is_normal: accepted for interface compatibility; it is not read, and
            the vectors are normalised regardless of its value.

    Returns:
        A list of `num` lists of floats.
    """
    rows = []
    for _ in range(num):
        rows.append([random.random() for _ in range(dim)])
    normalised = preprocessing.normalize(rows, axis=1, norm='l2')
    return normalised.tolist()
def gen_binary_vectors(num, dim):
    """Generate `num` random bit vectors of length `dim`.

    Args:
        num: number of vectors to generate.
        dim: number of bits per vector.

    Returns:
        A tuple (raw_vectors, binary_vectors): `raw_vectors` holds each vector
        as a list of 0/1 ints, `binary_vectors` holds the same vectors packed
        with numpy.packbits into `bytes`.
    """
    raw_vectors = []
    binary_vectors = []
    for _ in range(num):
        raw_vector = [random.randint(0, 1) for _ in range(dim)]
        # Keep the unpacked bits so the two returned lists stay aligned.
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors
def gen_binary_sub_vectors(vectors, length):
    """Rebuild the first `length` bit vectors, keeping only entries equal to 1.

    Args:
        vectors: sequence of equal-length bit sequences; the dimension is
            taken from vectors[0].
        length: how many leading vectors to process.

    Returns:
        A tuple (raw_vectors, binary_vectors): the 0/1 lists and the same
        vectors packed with numpy.packbits into `bytes`.
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for i in range(length):
        raw_vector = [0 for _ in range(dim)]
        vector = vectors[i]
        for index, j in enumerate(vector):
            if j == 1:
                raw_vector[index] = 1
        # Keep the unpacked bits so the two returned lists stay aligned.
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors
def gen_binary_super_vectors(vectors, length):
    """Generate `length` all-ones bit vectors with the dimension of vectors[0].

    Args:
        vectors: sequence of bit sequences; only the length of vectors[0]
            is used.
        length: number of vectors to generate.

    Returns:
        A tuple (raw_vectors, binary_vectors): the all-ones lists and the
        same vectors packed with numpy.packbits into `bytes`.
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for _ in range(length):
        raw_vector = [1 for _ in range(dim)]
        # Keep the unpacked bits so the two returned lists stay aligned.
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors
def gen_int_attr(row_num):
    """Return `row_num` random integers drawn from [0, 255] inclusive."""
    values = []
    for _ in range(row_num):
        values.append(random.randint(0, 255))
    return values
def gen_float_attr(row_num):
    """Return `row_num` random floats drawn with random.uniform(0, 255)."""
    values = []
    for _ in range(row_num):
        values.append(random.uniform(0, 255))
    return values
def gen_unique_str(str_value=None):
    """Return a name ending in 8 random ASCII letters/digits.

    Args:
        str_value: optional leading text. When None the result is
            "test_<suffix>", otherwise "<str_value>_<suffix>".
    """
    alphabet = string.ascii_letters + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(8))
    if str_value is None:
        return "test_" + suffix
    return str_value + "_" + suffix
def gen_single_filter_fields():
    """Return one field description per scalar DataType member.

    Iterates DataType in its declaration order and keeps INT32, INT64, FLOAT
    and DOUBLE; each entry is {"name": <member name>, "type": <member>}.
    """
    scalar_types = [DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]
    return [{"name": member.name, "type": member}
            for member in DataType
            if member in scalar_types]
def gen_single_vector_fields():
    """Return one field description per vector DataType.

    Each entry is {"name": <member name>, "type": <member>,
    "params": {"dim": default_dim}} for FLOAT_VECTOR and BINARY_VECTOR.
    """
    fields = []
    for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
        field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}}
        # Without this append the function always returned an empty list.
        fields.append(field)
    return fields
def gen_default_fields(auto_id=True, binary=False):
    """Build the default collection schema used by the tests.

    Args:
        auto_id: value stored under the "auto_id" key.
        binary: when False a float vector field is added, otherwise a binary
            vector field.

    Returns:
        A dict with "fields" (int64, float and one vector field),
        "segment_row_limit" and "auto_id".
    """
    fields = [
        {"name": default_int_field_name, "type": "INT64"},
        {"name": default_float_field_name, "type": "FLOAT"}
    ]
    if binary is False:
        field = {"name": default_float_vec_field_name, "type": "VECTOR_FLOAT",
                 "params": {"dim": default_dim}}
    else:
        # "VECTOR_BINARY" matches the type name get_default_field_name() accepts;
        # the earlier "BINARY_FLOAT" spelling is not used anywhere else.
        field = {"name": default_binary_vec_field_name, "type": "VECTOR_BINARY",
                 "params": {"dim": default_dim}}
    fields.append(field)
    default_fields = {
        "fields": fields,
        "segment_row_limit": default_segment_row_limit,
        "auto_id": auto_id
    }
    return default_fields
def gen_entities(nb, is_normal=False):
    """Generate `nb` row-style entities with int, float and float-vector values.

    Args:
        nb: number of entities to generate.
        is_normal: forwarded to gen_vectors().

    Returns:
        A list of dicts; row i holds "int64": i, "float": float(i) and the
        i-th generated vector under the default float vector field name.
    """
    entities = []
    vectors = gen_vectors(nb, default_dim, is_normal)
    for i in range(nb):
        entity = {
            "int64": i,
            "float": float(i),
            default_float_vec_field_name: vectors[i]
        }
        # Without this append the function always returned an empty list.
        entities.append(entity)
    return entities
def gen_binary_entities(nb):
    """Generate `nb` row-style entities carrying a packed binary vector.

    Args:
        nb: number of entities to generate.

    Returns:
        A tuple (raw_vectors, entities): the unpacked bit lists from
        gen_binary_vectors() and one dict per row holding the int value i,
        float(i) and the i-th packed vector.
    """
    raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
    entities = []
    for i in range(nb):
        entity = {
            default_int_field_name: i,
            default_float_field_name: float(i),
            # One vector per row (as gen_entities does), not the whole batch.
            default_binary_vec_field_name: vectors[i]
        }
        entities.append(entity)
    return raw_vectors, entities
def gen_entities_by_fields(fields, nb, dim):
    """Attach `nb` generated values to every field description.

    Each dict in `fields` is updated IN PLACE with a "values" key and then
    collected into the returned list, so the caller's dicts are modified.

    Args:
        fields: iterable of dicts with a "type" key holding a DataType member.
        nb: number of values to generate per field.
        dim: vector dimension for vector fields.

    Returns:
        A list containing the same (now updated) field dicts.

    Raises:
        ValueError: if a field's type is not one of the handled DataType
            members (previously this reused a stale value or failed with
            UnboundLocalError).
    """
    entities = []
    for field in fields:
        if field["type"] in [DataType.INT32, DataType.INT64]:
            field_value = [1 for _ in range(nb)]
        elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
            field_value = [3.0 for _ in range(nb)]
        elif field["type"] == DataType.BINARY_VECTOR:
            # Index 1 selects the packed bytes form returned by gen_binary_vectors.
            field_value = gen_binary_vectors(nb, dim)[1]
        elif field["type"] == DataType.FLOAT_VECTOR:
            field_value = gen_vectors(nb, dim)
        else:
            raise ValueError("Unsupported field type: %r" % (field["type"],))
        field.update({"values": field_value})
        entities.append(field)
    return entities
def assert_equal_entity(a, b):
    """Placeholder for an entity comparison; currently performs no check.

    TODO: implement the comparison. Kept as a no-op so existing callers
    keep working.
    """
    pass
def gen_query_vectors(field_name, entities, top_k, nq, search_params=None, rand_vector=False,
                      metric_type="L2", replace_vecs=None):
    """Build a boolean vector query and return it with the vectors it uses.

    Args:
        field_name: vector field the query targets.
        entities: row-style entities (dicts) holding a vector under the
            default float vector field name.
        top_k: value stored under "topk".
        nq: number of query vectors.
        search_params: value stored under "params"; defaults to a fresh
            {"nprobe": 10} per call.
        rand_vector: when True, generate `nq` random vectors with the same
            dimension as the first entity's vector; otherwise reuse the
            vectors of the first `nq` entities.
        metric_type: value stored under "metric_type".
        replace_vecs: when truthy, overrides the query vectors.

    Returns:
        A tuple (query, query_vectors).
    """
    if search_params is None:
        # Fresh dict per call so callers never share one default object.
        search_params = {"nprobe": 10}
    if rand_vector is True:
        # Each entity stores a single vector, so its length is the dimension.
        dimension = len(entities[0][default_float_vec_field_name])
        query_vectors = gen_vectors(nq, dimension)
    else:
        query_vectors = list(map(lambda x: x[default_float_vec_field_name], entities[:nq]))
    if replace_vecs:
        query_vectors = replace_vecs
    must_param = {"vector": {field_name: {"topk": top_k, "values": query_vectors, "params": search_params}}}
    must_param["vector"][field_name]["metric_type"] = metric_type
    query = {
        "bool": {
            "must": [must_param]
        }
    }
    return query, query_vectors
def update_query_expr(src_query, keep_old=True, expr=None):
    """Return a deep copy of `src_query` with its "must" clauses adjusted.

    Args:
        src_query: query dict shaped like {"bool": {"must": [...]}}; it is
            never modified.
        keep_old: when not True, the first existing "must" clause is dropped.
        expr: optional clause appended to the "must" list.

    Returns:
        The adjusted copy.
    """
    tmp_query = copy.deepcopy(src_query)
    if expr is not None:
        tmp_query["bool"]["must"].append(expr)
    if keep_old is not True:
        # Drop the original leading clause, leaving only what was appended.
        tmp_query["bool"]["must"].pop(0)
    return tmp_query
def gen_default_vector_expr(default_query):
    """Return the first "must" clause of a {"bool": {"must": [...]}} query.

    The clause is returned by reference, not copied.
    """
    must_clauses = default_query["bool"]["must"]
    return must_clauses[0]
def gen_default_term_expr(keyword="term", field="int64", values=None):
    """Build a term expression {keyword: {field: {"values": values}}}.

    Args:
        keyword: outer key of the expression.
        field: field the term applies to.
        values: values to match; defaults to the integers
            0 .. default_nb // 2 - 1.
    """
    if values is None:
        values = list(range(default_nb // 2))
    inner = {"values": values}
    return {keyword: {field: inner}}
def update_term_expr(src_term, terms):
    """Return a deep copy of `src_term` with extra per-field terms merged in.

    Args:
        src_term: expression shaped like {"term": {field: {...}}}; it is
            never modified.
        terms: iterable of {field: {...}} dicts merged into the "term"
            mapping in order (later entries override earlier ones).

    Returns:
        The merged copy.
    """
    tmp_term = copy.deepcopy(src_term)
    for term in terms:
        tmp_term["term"].update(term)
    return tmp_term
def gen_default_range_expr(keyword="range", field="int64", ranges=None):
    """Build a range expression {keyword: {field: ranges}}.

    Args:
        keyword: outer key of the expression.
        field: field the range applies to.
        ranges: bounds mapping; defaults to
            {"GT": 1, "LT": default_nb // 2}.
    """
    bounds = ranges
    if bounds is None:
        bounds = {"GT": 1, "LT": default_nb // 2}
    return {keyword: {field: bounds}}
def update_range_expr(src_range, ranges):
    """Return a deep copy of `src_range` with extra per-field ranges merged in.

    Args:
        src_range: expression shaped like {"range": {field: {...}}}; it is
            never modified.
        ranges: iterable of {field: {...}} dicts merged into the "range"
            mapping in order (later entries override earlier ones).

    Returns:
        The merged copy.
    """
    tmp_range = copy.deepcopy(src_range)
    # Loop variable renamed so the builtin `range` is not shadowed.
    for extra_range in ranges:
        tmp_range["range"].update(extra_range)
    return tmp_range
def gen_invalid_range():
    """Return malformed "range" expressions for negative tests.

    Covers a scalar, an empty dict, an empty list and a doubly nested
    "range" key.
    """
    # Local renamed from `range` so the builtin is not shadowed.
    invalid_ranges = [
        {"range": 1},
        {"range": {}},
        {"range": []},
        {"range": {"range": {"int64": {"GT": 0, "LT": default_nb // 2}}}}
    ]
    return invalid_ranges
def gen_valid_ranges():
    """Return well-formed range bound mappings derived from the module defaults.

    Includes two-sided, lower-only and upper-only bounds.
    """
    ranges = [
        {"GT": 0, "LT": default_nb // 2},
        {"GT": default_nb // 2, "LT": default_nb * 2},
        {"GT": 0},
        {"LT": default_nb},
        {"GT": -1, "LT": default_top_k},
    ]
    return ranges
def gen_invalid_term():
    """Return malformed "term" expressions for negative tests.

    Covers a scalar, an empty list, an empty dict and a doubly nested
    "term" key.
    """
    terms = [
        {"term": 1},
        {"term": []},
        {"term": {}},
        {"term": {"term": {"int64": {"values": [i for i in range(default_nb // 2)]}}}}
    ]
    return terms
def add_field_default(default_fields, type=DataType.INT64, field_name=None):
    """Return a deep copy of a schema with one extra field appended.

    Args:
        default_fields: schema dict with a "fields" list; never modified.
        type: type of the new field.
        field_name: name of the new field; a random one is generated with
            gen_unique_str() when None.

    Returns:
        The copied schema including the new field.
    """
    tmp_fields = copy.deepcopy(default_fields)
    if field_name is None:
        field_name = gen_unique_str()
    field = {
        "name": field_name,
        "type": type
    }
    # Without this append the copy was returned unchanged.
    tmp_fields["fields"].append(field)
    return tmp_fields
def add_field(entities, field_name=None):
    """Return a deep copy of column-style entities with an extra INT64 field.

    Args:
        entities: list of field dicts with a "values" list; never modified.
            The row count is taken from entities[0]["values"].
        field_name: name of the new field; a random one is generated with
            gen_unique_str() when None.

    Returns:
        The copied list including the new field, whose values are 0..nb-1.
    """
    nb = len(entities[0]["values"])
    tmp_entities = copy.deepcopy(entities)
    if field_name is None:
        field_name = gen_unique_str()
    field = {
        "name": field_name,
        "type": DataType.INT64,
        "values": [i for i in range(nb)]
    }
    # Without this append the copy was returned unchanged.
    tmp_entities.append(field)
    return tmp_entities
def add_vector_field(entities, is_normal=False):
    """Append an extra float-vector field to column-style entities IN PLACE.

    NOTE(review): this module defines a second `add_vector_field(nb, dimension)`
    further down, which replaces this definition at import time. One of the
    two should be renamed.

    Args:
        entities: list of field dicts with a "values" list; the row count is
            taken from entities[0]["values"].
        is_normal: forwarded to gen_vectors().

    Returns:
        The same `entities` list, now including the new vector field.
    """
    nb = len(entities[0]["values"])
    vectors = gen_vectors(nb, default_dim, is_normal)
    field = {
        "name": gen_unique_str(),
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }
    # Without this append the input was returned unchanged.
    entities.append(field)
    return entities
# def update_fields_metric_type(fields, metric_type):
# tmp_fields = copy.deepcopy(fields)
# if metric_type in ["L2", "IP"]:
# tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR
# else:
# tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR
# tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type
# return tmp_fields
def remove_field(entities):
    """Delete the first element of `entities` in place and return the same object.

    Raises IndexError when `entities` is an empty list.
    """
    first = 0
    del entities[first]
    return entities
def remove_vector_field(entities):
    """Delete the last element of `entities` in place and return the same object.

    Callers presumably keep the vector field last; the function itself only
    removes whatever sits at index -1. Raises IndexError on an empty list.
    """
    last = -1
    del entities[last]
    return entities
def update_field_name(entities, old_name, new_name):
    """Return a deep copy of `entities` with matching fields renamed.

    Args:
        entities: list of field dicts with a "name" key; never modified.
        old_name: name to look for.
        new_name: replacement name.
    """
    renamed = copy.deepcopy(entities)
    for field in renamed:
        if field["name"] != old_name:
            continue
        field["name"] = new_name
    return renamed
def update_field_type(entities, old_name, new_name):
    """Return a deep copy of `entities` with the type of matching fields replaced.

    Args:
        entities: list of field dicts with "name" and "type" keys; never
            modified.
        old_name: NAME of the field(s) to change.
        new_name: despite its name, the new value stored under "type".
    """
    retyped = copy.deepcopy(entities)
    for field in retyped:
        if field["name"] != old_name:
            continue
        field["type"] = new_name
    return retyped
def update_field_value(entities, old_type, new_value):
    """Return a deep copy of `entities` with every value of matching fields replaced.

    Args:
        entities: list of field dicts with "type" and "values" keys; never
            modified.
        old_type: type of the field(s) whose values are overwritten.
        new_value: object written into every slot of the matching "values".
    """
    updated = copy.deepcopy(entities)
    for field in updated:
        if field["type"] != old_type:
            continue
        slots = field["values"]
        for position in range(len(slots)):
            slots[position] = new_value
    return updated
def add_vector_field(nb, dimension=default_dim):
    """Generate a random vector field description and return only its name.

    NOTE(review): this definition replaces the earlier
    `add_vector_field(entities, is_normal)` in this module, since both share
    one name. The `field` dict built here is not stored or returned; only
    the generated name reaches the caller.

    Args:
        nb: number of vectors generated for the field.
        dimension: dimension of each generated vector.

    Returns:
        The randomly generated field name.
    """
    field_name = gen_unique_str()
    field = {
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(nb, dimension)
    }
    return field_name
def gen_segment_row_limits():
sizes = [
return sizes
def gen_invalid_ips():
    """Return strings that must be rejected as server IP addresses.

    The last entry is a 511-character string of "a".
    """
    ips = [
        # "",
        # "",
        # "",
        # "",
        # "",
        " ",
        "12 s",
        " siede ",
        "a".join("a" for _ in range(256))
    ]
    return ips
def gen_invalid_uris():
    """Return URIs that must be rejected when connecting.

    Most historical cases are commented out; `ip` is only referenced by
    those disabled entries.
    """
    ip = None
    uris = [
        " ",

        # invalid protocol
        # "tc://%s:%s" % (ip, port),
        # "tcp%s:%s" % (ip, port),

        # # invalid port
        # "tcp://%s:100000" % ip,
        # "tcp://%s: " % ip,
        # "tcp://%s:19540" % ip,
        # "tcp://%s:-1" % ip,
        # "tcp://%s:string" % ip,

        # invalid ip
        "tcp:// :19530",
        # "tcp://" % port,
        # "tcp://" % port,
        # "tcp://" % port,
        # "tcp://" % port,
        # "tcp://" % port,
    ]
    return uris
def gen_invalid_strs():
    """Return strings that must be rejected as names.

    The last entry is a 511-character string of "a".
    """
    strings = [
        " ",
        # "",
        # None,
        "12 s",
        " siede ",
        "a".join("a" for i in range(256))
    ]
    return strings
def gen_invalid_field_types():
    """Return values that must be rejected as field types.

    Currently a single 511-character string of "a".
    """
    field_types = [
        # 1,
        # 0,
        "a".join("a" for i in range(256))
    ]
    return field_types
def gen_invalid_metric_types():
    """Return values that must be rejected as metric types.

    Currently a single 511-character string of "a".
    """
    metric_types = [
        "a".join("a" for i in range(256))
    ]
    return metric_types
def gen_invalid_ints():
    """Return non-integer values for parameters that require an int.

    Covers a list, a blank string and a 511-character string of "a".
    """
    int_values = [
        # 1.0,
        [1, 2, 3],
        " ",
        "a".join("a" for i in range(256))
    ]
    return int_values
def gen_invalid_params():
    """Return invalid values for index and search parameters.

    Covers a list and a blank string; a fresh list is built per call.
    """
    params = [
        # None,
        [1, 2, 3],
        " ",
    ]
    return params
def gen_invalid_vectors():
    """Return values that must be rejected as vector data.

    Covers wrong-length and wrong-typed containers plus several strings;
    the last entry is a 511-character string of "a".
    """
    invalid_vectors = [
        [1, 2],
        [" "],
        (1, 2),
        {"a": 1},
        " ",
        " siede ",
        "a".join("a" for i in range(256))
    ]
    return invalid_vectors
def gen_invaild_search_params():
    """Build invalid search-parameter cases for every enabled index type.

    "FLAT" is skipped. Every other type gets one case with an unknown key,
    plus one case per value of gen_invalid_params() for that type's own
    search parameter (nprobe, ef, search_length or search_k).

    Returns:
        A list of {"index_type": ..., "search_params": {...}} dicts.
    """
    invalid_search_key = 100
    search_params = []
    for index_type in all_index_types:
        if index_type == "FLAT":
            continue
        search_params.append({"index_type": index_type, "search_params": {"invalid_key": invalid_search_key}})
        if index_type in delete_support():
            for nprobe in gen_invalid_params():
                ivf_search_params = {"index_type": index_type, "search_params": {"nprobe": nprobe}}
                search_params.append(ivf_search_params)
        elif index_type in ["HNSW", "RHNSW_PQ", "RHNSW_SQ"]:
            for ef in gen_invalid_params():
                hnsw_search_param = {"index_type": index_type, "search_params": {"ef": ef}}
                search_params.append(hnsw_search_param)
        elif index_type == "NSG":
            for search_length in gen_invalid_params():
                nsg_search_param = {"index_type": index_type, "search_params": {"search_length": search_length}}
                search_params.append(nsg_search_param)
            search_params.append({"index_type": index_type, "search_params": {"invalid_key": 100}})
        elif index_type == "ANNOY":
            for search_k in gen_invalid_params():
                # An int is a valid search_k, so it is not an invalid case.
                if isinstance(search_k, int):
                    continue
                annoy_search_param = {"index_type": index_type, "search_params": {"search_k": search_k}}
                search_params.append(annoy_search_param)
    return search_params
def gen_invalid_index():
    """Build invalid index-creation parameter cases.

    Combines invalid index type names, invalid values for each known build
    parameter (nlist, M, efConstruction, the NSG parameters, n_trees) and
    cases carrying an unknown parameter key.

    Returns:
        A list of {"index_type": ..., "params": {...}} dicts.
    """
    index_params = []
    for index_type in gen_invalid_strs():
        index_param = {"index_type": index_type, "params": {"nlist": 1024}}
        index_params.append(index_param)
    for nlist in gen_invalid_params():
        index_param = {"index_type": "IVF_FLAT", "params": {"nlist": nlist}}
        index_params.append(index_param)
    for M in gen_invalid_params():
        index_param = {"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}
        index_params.append(index_param)
        index_param = {"index_type": "RHNSW_PQ", "params": {"M": M, "efConstruction": 100}}
        index_params.append(index_param)
        index_param = {"index_type": "RHNSW_SQ", "params": {"M": M, "efConstruction": 100}}
        index_params.append(index_param)
    for efConstruction in gen_invalid_params():
        index_param = {"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}
        index_params.append(index_param)
        index_param = {"index_type": "RHNSW_PQ", "params": {"M": 16, "efConstruction": efConstruction}}
        index_params.append(index_param)
        index_param = {"index_type": "RHNSW_SQ", "params": {"M": 16, "efConstruction": efConstruction}}
        index_params.append(index_param)
    for search_length in gen_invalid_params():
        index_param = {"index_type": "NSG",
                       "params": {"search_length": search_length, "out_degree": 40, "candidate_pool_size": 50,
                                  "knng": 100}}
        index_params.append(index_param)
    for out_degree in gen_invalid_params():
        index_param = {"index_type": "NSG",
                       "params": {"search_length": 100, "out_degree": out_degree, "candidate_pool_size": 50,
                                  "knng": 100}}
        index_params.append(index_param)
    for candidate_pool_size in gen_invalid_params():
        index_param = {"index_type": "NSG", "params": {"search_length": 100, "out_degree": 40,
                                                       "candidate_pool_size": candidate_pool_size,
                                                       "knng": 100}}
        index_params.append(index_param)
    # Cases whose parameter KEY is unknown rather than its value.
    index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
    index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
    index_params.append({"index_type": "RHNSW_PQ", "params": {"invalid_key": 16, "efConstruction": 100}})
    index_params.append({"index_type": "RHNSW_SQ", "params": {"invalid_key": 16, "efConstruction": 100}})
    index_params.append({"index_type": "NSG",
                         "params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
                                    "knng": 100}})
    for invalid_n_trees in gen_invalid_params():
        index_params.append({"index_type": "ANNOY", "params": {"n_trees": invalid_n_trees}})
    return index_params
def gen_index():
    """Build the full grid of index-creation parameters for enabled index types.

    For each entry of `all_index_types` the cartesian product of the
    candidate values for that type's build parameters is generated.

    Returns:
        A list of {"index_type": ..., "index_param": {...}} dicts.
    """
    nlists = [1, 1024, 16384]
    pq_ms = [128, 64, 32, 16, 8, 4]
    Ms = [5, 24, 48]
    efConstructions = [100, 300, 500]
    search_lengths = [10, 100, 300]
    out_degrees = [5, 40, 300]
    candidate_pool_sizes = [50, 100, 300]
    knngs = [5, 100, 300]

    index_params = []
    for index_type in all_index_types:
        if index_type in ["FLAT", "BIN_FLAT", "BIN_IVF_FLAT"]:
            index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}})
        elif index_type in ["IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID"]:
            ivf_params = [{"index_type": index_type, "index_param": {"nlist": nlist}}
                          for nlist in nlists]
            index_params.extend(ivf_params)
        elif index_type == "IVF_PQ":
            IVFPQ_params = [{"index_type": index_type, "index_param": {"nlist": nlist, "m": m}}
                            for nlist in nlists
                            for m in pq_ms]
            index_params.extend(IVFPQ_params)
        elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
            hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}}
                           for M in Ms
                           for efConstruction in efConstructions]
            index_params.extend(hnsw_params)
        elif index_type == "NSG":
            nsg_params = [{"index_type": index_type,
                           "index_param": {"search_length": search_length, "out_degree": out_degree,
                                           "candidate_pool_size": candidate_pool_size, "knng": knng}}
                          for search_length in search_lengths
                          for out_degree in out_degrees
                          for candidate_pool_size in candidate_pool_sizes
                          for knng in knngs]
            index_params.extend(nsg_params)

    return index_params
def gen_simple_index():
    """Return one default index description per non-binary index type.

    Each entry pairs all_index_types[i] with default_index_params[i] and
    uses the "L2" metric; binary index types are skipped.
    """
    index_params = []
    for i in range(len(all_index_types)):
        if all_index_types[i] in binary_support():
            continue
        dic = {"index_type": all_index_types[i], "metric_type": "L2"}
        dic.update({"params": default_index_params[i]})
        index_params.append(dic)
    return index_params
def gen_binary_index():
    """Return one default index description per binary index type.

    Each entry pairs all_index_types[i] with default_index_params[i]; no
    metric type is set here.
    """
    index_params = []
    for i in range(len(all_index_types)):
        if all_index_types[i] in binary_support():
            dic = {"index_type": all_index_types[i]}
            dic.update({"params": default_index_params[i]})
            index_params.append(dic)
    return index_params
def get_search_param(index_type, metric_type="L2"):
    """Return default search parameters for `index_type`.

    Args:
        index_type: index type name.
        metric_type: value stored under "metric_type".

    Returns:
        A dict with "metric_type" plus the search parameter that matches the
        index family (nprobe, ef, search_length or search_k).

    Raises:
        Exception: if `index_type` belongs to none of the known families.
    """
    search_params = {"metric_type": metric_type}
    if index_type in ivf() or index_type in binary_support():
        search_params.update({"nprobe": 64})
    elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
        search_params.update({"ef": 64})
    elif index_type == "NSG":
        search_params.update({"search_length": 100})
    elif index_type == "ANNOY":
        search_params.update({"search_k": 1000})
    else:
        # Only unknown index types are an error; known ones fall through to return.
        logging.getLogger().error("Invalid index_type.")
        raise Exception("Invalid index_type.")
    return search_params
def assert_equal_vector(v1, v2):
    """Assert that two vectors have equal length and near-equal components.

    Components are compared pairwise and must differ by less than the
    module-level `epsilon`; any violation raises AssertionError.
    """
    assert len(v1) == len(v2)
    for left, right in zip(v1, v2):
        assert abs(left - right) < epsilon
def restart_server(helm_release_name):
res = True
timeout = 120
from kubernetes import client, config
logging.getLogger().debug("Pod name: %s" % pod_name)
if pod_name is not None:
v1.delete_namespaced_pod(pod_name, namespace)
except Exception as e:
logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
res = False
return res
logging.error("Sleep 10s after pod deleted")
# check if restart successfully
pods = v1.list_namespaced_pod(namespace)
for i in pods.items:
pod_name_tmp = i.metadata.name
if pod_name_tmp == pod_name:
elif pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1:
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
start_time = time.time()
ready_break = False
while time.time() - start_time <= timeout:
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
if status_res.status.phase == "Running":
logging.error("Already running")
ready_break = True
if time.time() - start_time > timeout:
logging.error("Restart pod: %s timeout" % pod_name_tmp)
res = False
return res
if ready_break:
raise Exception("Pod: %s not found" % pod_name)
follow = True
pretty = True
## Requirements
# Requirements
## Requirements
* python 3.6.8+
* pip install -r requirements.txt
# How to use this Test Project
## How to Build Test Env
pytest . --level=1
sudo docker pull registry.zilliz.com/milvus/milvus-test:v0.2
sudo docker run -it -v /home/zilliz:/home/zilliz -d registry.zilliz.com/milvus/milvus-test:v0.2
or test connect function only
## How to Create Test Env docker in k8s
# 1. start milvus k8s pod
cd milvus-helm/charts/milvus
helm install --wait --timeout 300s \
--set image.repository=registry.zilliz.com/milvus/engine \
--set persistence.enabled=true \
--set image.tag=PR-3818-gpu-centos7-release \
--set image.pullPolicy=Always \
--set service.type=LoadBalancer \
-f ci/db_backend/mysql_gpu_values.yaml \
-f ci/filebeat/values.yaml \
-f test.yaml \
--namespace milvus \
milvus-ci-pr-3818-1-single-centos7-gpu .
# 2. remove milvus k8s pod
helm uninstall -n milvus milvus-test
# 3. check k8s pod status
kubectl get svc -n milvus -w milvus-test
# 4. login to pod
kubectl get pods --namespace milvus
kubectl exec -it milvus-test-writable-6cc49cfcd4-rbrns -n milvus bash
## How to Run Test cases
# Test level-1 cases
pytest . --level=1 --ip= --port=19530
# Test level-1 cases in 'test_connect.py' only
pytest test_connect.py --level=1
collect cases
pytest --dry-run -qq
collect cases with docstring
## How to list test cases
# List all cases
pytest --dry-run -qq
# Collect all cases with docstring
pytest --collect-only -qq
with allure test report
# Create test report with allure
pytest --alluredir=test_out . -q -v
allure serve test_out
# Contribution getting started
## Contribution getting started
* Follow PEP-8 for naming and black for formatting.
import pdb
import pytest
from utils import *
from constants import *
nb = 6000
dim = 128
tag = "tag"
collection_id = "count_collection"
add_interval_time = 3
segment_row_count = 5000
default_fields = gen_default_fields()
default_binary_fields = gen_binary_default_fields()
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
field_name = "fload_vector"
index_name = "index_name"
uid = "collection_count"
tag = "collection_count_tag"
class TestCollectionCount:
@ -32,8 +22,8 @@ class TestCollectionCount:
def insert_count(self, request):
@ -155,7 +145,7 @@ class TestCollectionCount:
entities = gen_entities(insert_count)
res = connect.insert(collection, entities)
connect.create_index(collection, field_name, get_simple_index)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
res = connect.count_entities(collection)
assert res == insert_count
class TestCollectionCountIP:
def insert_count(self, request):
class TestCollectionCountBinary:
def insert_count(self, request):
class TestCollectionMultiCollections:
def insert_count(self, request):
@ -485,7 +475,7 @@ class TestCollectionMultiCollections:
collection_list = []
collection_num = 20
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
res = connect.insert(collection_name, entities)
@ -507,7 +497,7 @@ class TestCollectionMultiCollections:
collection_list = []
collection_num = 20
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_binary_fields)
res = connect.insert(collection_name, entities)
@ -527,16 +517,16 @@ class TestCollectionMultiCollections:
collection_list = []
collection_num = 20
for i in range(0, int(collection_num / 2)):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
res = connect.insert(collection_name, entities)
res = connect.insert(collection_name, default_entities)
for i in range(int(collection_num / 2), collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_binary_fields)
res = connect.insert(collection_name, binary_entities)
res = connect.insert(collection_name, default_binary_entities)
for i in range(collection_num):
res = connect.count_entities(collection_list[i])
assert res == nb
assert res == default_nb
from multiprocessing import Process
from multiprocessing import Process
from utils import *
dim = 128
default_segment_row_count = 100000
drop_collection_interval_time = 3
segment_row_count = 5000
collection_id = "logic"
vectors = gen_vectors(100, dim)
default_fields = gen_default_fields()
uid = "collection_logic"
def create_collection(connect, **params):
connect.create_collection(params["collection_name"], default_fields)
connect.create_collection(params["collection_name"], const.default_fields)
def search_collection(connect, **params):
status, result = connect.search(
@ -129,7 +122,7 @@ class TestCollectionLogic(object):
def gen_params(self):
collection_name = gen_unique_str("collection_id")
collection_name = gen_unique_str(uid)
top_k = 1
vectors = gen_vectors(2, dim)
param = {'collection_name': collection_name,
@ -3,25 +3,12 @@ import pdb
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
nprobe = 1
top_k = 1
epsilon = 0.0001
tag = "1970_01_01"
nb = 6000
nlist = 1024
collection_id = "collection_stats"
field_name = "float_vector"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
class TestStatsBase:
class TestStatsBase:
@ -65,10 +52,11 @@ class TestStatsBase:
method: call collection_stats with a random collection_name, which is not in db
expected: status not ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
stats = connect.get_collection_stats(collection_name)
def test_get_collection_stats_name_invalid(self, connect, get_collection_name):
target: get collection stats where collection name is invalid
@ -88,7 +76,7 @@ class TestStatsBase:
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == 0
assert len(stats["partitions"]) == 1
assert stats["partitions"][0]["tag"] == "_default"
assert stats["partitions"][0]["tag"] == default_partition_name
assert stats["partitions"][0]["row_count"] == 0
def test_get_collection_stats_batch(self, connect, collection):
@ -97,13 +85,13 @@ class TestStatsBase:
method: add entities, check count in collection info
expected: count as expected
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
assert len(stats["partitions"]) == 1
assert stats["partitions"][0]["tag"] == "_default"
assert stats["partitions"][0]["row_count"] == nb
assert stats["partitions"][0]["tag"] == default_partition_name
assert stats["partitions"][0]["row_count"] == default_nb
def test_get_collection_stats_single(self, connect, collection):
@ -113,12 +101,12 @@ class TestStatsBase:
nb = 10
for i in range(nb):
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb
assert len(stats["partitions"]) == 1
assert stats["partitions"][0]["tag"] == "_default"
assert stats["partitions"][0]["tag"] == default_partition_name
assert stats["partitions"][0]["row_count"] == nb
def test_get_collection_stats_after_delete(self, connect, collection):
@ -127,14 +115,14 @@ class TestStatsBase:
method: add and delete entities, check count in collection info
expected: status ok, count as expected
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
status = connect.flush([collection])
delete_ids = [ids[0], ids[-1]]
connect.delete_entity_by_id(collection, delete_ids)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb - 2
assert stats["partitions"][0]["row_count"] == nb -2
assert stats["row_count"] == default_nb - 2
assert stats["partitions"][0]["row_count"] == default_nb - 2
assert stats["partitions"][0]["segments"][0]["data_size"] > 0
# TODO: enable
@ -145,20 +133,21 @@ class TestStatsBase:
method: add and delete entities, and compact collection, check count in collection info
expected: status ok, count as expected
ids = connect.insert(collection, entities)
delete_length = 1000
ids = connect.insert(collection, default_entities)
status = connect.flush([collection])
delete_ids = ids[:3000]
delete_ids = ids[:delete_length]
connect.delete_entity_by_id(collection, delete_ids)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb - 3000
assert stats["row_count"] == default_nb - delete_length
compact_before = stats["partitions"][0]["segments"][0]["data_size"]
stats = connect.get_collection_stats(collection)
compact_after = stats["partitions"][0]["segments"][0]["data_size"]
assert compact_before > compact_after
assert compact_before == compact_after
def test_get_collection_stats_after_compact_delete_one(self, connect, collection):
@ -166,7 +155,7 @@ class TestStatsBase:
method: add and delete one entity, and compact collection, check count in collection info
expected: status ok, count as expected
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
status = connect.flush([collection])
delete_ids = ids[:1]
connect.delete_entity_by_id(collection, delete_ids)
@ -187,13 +176,13 @@ class TestStatsBase:
method: call collection_stats after partition created and check partition_stats
expected: status ok, vectors added to partition
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
stats = connect.get_collection_stats(collection)
assert len(stats["partitions"]) == 2
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["row_count"] == nb
assert stats["partitions"][1]["tag"] == default_tag
assert stats["partitions"][1]["row_count"] == default_nb
def test_get_collection_stats_partitions(self, connect, collection):
@ -202,22 +191,22 @@ class TestStatsBase:
expected: status ok, vectors added to one partition but not the other
new_tag = "new_tag"
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
ids = connect.insert(collection, entities, partition_tag=tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
stats = connect.get_collection_stats(collection)
for partition in stats["partitions"]:
if partition["tag"] == tag:
assert partition["row_count"] == nb
if partition["tag"] == default_tag:
assert partition["row_count"] == default_nb
assert partition["row_count"] == 0
ids = connect.insert(collection, entities, partition_tag=new_tag)
ids = connect.insert(collection, default_entities, partition_tag=new_tag)
stats = connect.get_collection_stats(collection)
for partition in stats["partitions"]:
if partition["tag"] in [tag, new_tag]:
assert partition["row_count"] == nb
if partition["tag"] in [default_tag, new_tag]:
assert partition["row_count"] == default_nb
def test_get_collection_stats_after_index_created(self, connect, collection, get_simple_index):
@ -225,15 +214,17 @@ class TestStatsBase:
method: create collection, add vectors, create index and call collection_stats
expected: status ok, index created and shown in segments
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
if file["name"] == default_float_vec_field_name and "index_type" in file:
assert file["data_size"] > 0
assert file["index_type"] == get_simple_index["index_type"]
def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
@ -242,16 +233,17 @@ class TestStatsBase:
expected: status ok, index created and shown in segments
get_simple_index["metric_type"] = "IP"
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_simple_index.update({"metric_type": "IP"})
connect.create_index(collection, field_name, get_simple_index)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
if file["name"] == default_float_vec_field_name and "index_type" in file:
assert file["data_size"] > 0
assert file["index_type"] == get_simple_index["index_type"]
def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
@ -259,15 +251,16 @@ class TestStatsBase:
method: create collection, add binary entities, create index and call collection_stats
expected: status ok, index created and shown in segments
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
if file["name"] == default_float_vec_field_name and "index_type" in file:
assert file["data_size"] > 0
assert file["index_type"] == get_simple_index["index_type"]
def test_get_collection_stats_after_create_different_index(self, connect, collection):
@ -275,16 +268,18 @@ class TestStatsBase:
method: create collection, add vectors, create index and call collection_stats multiple times
expected: status ok, index info shown in segments
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
for index_type in ["IVF_FLAT", "IVF_SQ8"]:
connect.create_index(collection, field_name, {"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
connect.create_index(collection, default_float_vec_field_name,
{"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
stats = connect.get_collection_stats(collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
if file["name"] == default_float_vec_field_name and "index_type" in file:
assert file["data_size"] > 0
assert file["index_type"] == index_type
def test_collection_count_multi_collections(self, connect):
@ -296,14 +291,14 @@ class TestStatsBase:
collection_list = []
collection_num = 10
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
res = connect.insert(collection_name, entities)
res = connect.insert(collection_name, default_entities)
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
assert stats["partitions"][0]["row_count"] == nb
assert stats["partitions"][0]["row_count"] == default_nb
@ -317,23 +312,27 @@ class TestStatsBase:
collection_list = []
collection_num = 10
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
res = connect.insert(collection_name, entities)
res = connect.insert(collection_name, default_entities)
if i % 2:
connect.create_index(collection_name, field_name, {"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
connect.create_index(collection_name, default_float_vec_field_name,
{"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT","params":{ "nlist": 1024}, "metric_type": "L2"})
connect.create_index(collection_name, default_float_vec_field_name,
{"index_type": "IVF_FLAT","params":{"nlist": 1024}, "metric_type": "L2"})
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
if i % 2:
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
if file["name"] == default_float_vec_field_name and "index_type" in file:
assert file["index_type"] == "IVF_SQ8"
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
if file["name"] == default_float_vec_field_name and "index_type" in file:
assert file["index_type"] == "IVF_FLAT"
@ -9,18 +9,11 @@ import sklearn.preprocessing
import pytest
from utils import *
from constants import *
nb = 1
dim = 128
collection_id = "create_collection"
default_segment_row_count = 512 * 1024
drop_collection_interval_time = 3
segment_row_count = 5000
default_fields = gen_default_fields()
entities = gen_entities(nb)
uid = "create_collection"
class TestCreateCollection:
The following cases are used to test `create_collection` function
@ -42,9 +35,9 @@ class TestCreateCollection:
def get_segment_row_count(self, request):
def get_segment_row_limit(self, request):
yield request.param
def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
@ -56,10 +49,10 @@ class TestCreateCollection:
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count
"segment_row_limit": default_segment_row_limit
connect.create_collection(collection_name, fields)
@ -73,23 +66,23 @@ class TestCreateCollection:
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count
"segment_row_limit": default_segment_row_limit
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)
def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
target: test create normal collection with different fields
method: create collection with diff segment_row_count
method: create collection with diff segment_row_limit
expected: no exception raised
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_count
fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
assert connect.has_collection(collection_name)
@ -101,7 +94,7 @@ class TestCreateCollection:
expected: create status return ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
@ -115,7 +108,7 @@ class TestCreateCollection:
expected: error raised
# pdb.set_trace()
connect.insert(collection, entities)
connect.insert(collection, default_entity)
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
@ -126,7 +119,7 @@ class TestCreateCollection:
method: insert vector and create collection
expected: error raised
connect.insert(collection, entities)
connect.insert(collection, default_entity)
with pytest.raises(Exception) as e:
connect.create_collection(collection, default_fields)
@ -136,9 +129,9 @@ class TestCreateCollection:
target: test create collection, without connection
method: create collection with correct params, with a disconnected instance
expected: create raise exception
expected: error raised
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
@ -146,13 +139,23 @@ class TestCreateCollection:
target: test create collection when the collection name already exists
method: create collection with the same collection_name
expected: create status return not ok
expected: error raised
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, default_fields)
def test_create_after_drop_collection(self, connect, collection):
target: create with the same collection name after collection dropped
method: delete, then create
expected: create success
connect.create_collection(collection, default_fields)
def test_create_collection_multithread(self, connect):
@ -165,7 +168,7 @@ class TestCreateCollection:
collection_names = []
def create():
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
for i in range(threads_num):
@ -176,9 +179,9 @@ class TestCreateCollection:
for t in threads:
res = connect.list_collections()
for item in collection_names:
assert item in res
assert item in connect.list_collections()
class TestCreateCollectionInvalid(object):
@ -196,7 +199,7 @@ class TestCreateCollectionInvalid(object):
def get_segment_row_count(self, request):
def get_segment_row_limit(self, request):
yield request.param
@ -221,21 +224,13 @@ class TestCreateCollectionInvalid(object):
yield request.param
def test_create_collection_with_invalid_segment_row_count(self, connect, get_segment_row_count):
def test_create_collection_with_invalid_segment_row_limit(self, connect, get_segment_row_limit):
collection_name = gen_unique_str()
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_count
fields["segment_row_limit"] = get_segment_row_limit
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
# @pytest.mark.level(2)
# def test_create_collection_with_invalid_metric_type(self, connect, get_metric_type):
# collection_name = gen_unique_str()
# fields = copy.deepcopy(default_fields)
# fields["fields"][-1]["params"]["metric_type"] = get_metric_type
# with pytest.raises(Exception) as e:
# connect.create_collection(collection_name, fields)
def test_create_collection_with_invalid_dimension(self, connect, get_dim):
dimension = get_dim
@ -278,54 +273,55 @@ class TestCreateCollectionInvalid(object):
method: create collection with correct params
expected: create status return ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
def test_create_collection_no_segment_row_count(self, connect):
def test_create_collection_no_segment_row_limit(self, connect):
target: test create collection with no segment_row_count params
method: create collection with corrent params
expected: use default default_segment_row_count
target: test create collection with no segment_row_limit params
method: create collection with correct params
expected: use default default_segment_row_limit
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
connect.create_collection(collection_name, fields)
res = connect.get_collection_info(collection_name)
assert res["segment_row_limit"] == default_segment_row_count
assert res["segment_row_limit"] == default_server_segment_row_limit
# TODO: assert exception
def test_create_collection_limit_fields(self, connect):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
limit_num = 64
fields = copy.deepcopy(default_fields)
for i in range(limit_num):
field_name = gen_unique_str("field_name")
field = {"field": field_name, "type": DataType.INT64}
field = {"name": field_name, "type": DataType.INT64}
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
# TODO: assert exception
def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
field_name = get_invalid_string
field = {"field": field_name, "type": DataType.INT64}
field = {"name": field_name, "type": DataType.INT64}
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
# TODO: assert exception
def test_create_collection_invalid_field_type(self, connect, get_field_type):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
field_type = get_field_type
field = {"field": "test_field", "type": field_type}
field = {"name": "test_field", "type": field_type}
with pytest.raises(Exception) as e:
connect.create_collection(collection_name, fields)
@ -3,15 +3,14 @@ import pytest
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
from utils import *
from constants import *
uniq_id = "drop_collection"
default_fields = gen_default_fields()
class TestDropCollection:
The following cases are used to test `drop_collection` function
@ -48,6 +47,33 @@ class TestDropCollection:
with pytest.raises(Exception) as e:
def test_create_drop_collection_multithread(self, connect):
target: test create and drop collection with multithread
method: create and drop collections using multiple threads
expected: collections are created, and dropped
threads_num = 8
threads = []
collection_names = []
def create():
collection_name = gen_unique_str(collection_id)
connect.create_collection(collection_name, default_fields)
for i in range(threads_num):
t = threading.Thread(target=create, args=())
for t in threads:
for item in collection_names:
assert not connect.has_collection(item)
class TestDropCollectionInvalid(object):
@ -60,6 +86,7 @@ class TestDropCollectionInvalid(object):
def get_collection_name(self, request):
yield request.param
def test_drop_collection_with_invalid_collectionname(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
@ -6,13 +6,9 @@ from time import sleep
import threading
from multiprocessing import Process
from utils import *
from constants import *
nb = 1000
collection_id = "info"
default_fields = gen_default_fields()
segment_row_count = 5000
field_name = "float_vector"
uid = "collection_info"
class TestInfoBase:
@ -32,9 +28,9 @@ class TestInfoBase:
def get_segment_row_count(self, request):
def get_segment_row_limit(self, request):
yield request.param
@ -62,15 +58,15 @@ class TestInfoBase:
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count
"segment_row_limit": default_segment_row_limit
connect.create_collection(collection_name, fields)
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_limit'] == segment_row_count
assert res['segment_row_limit'] == default_segment_row_limit
assert len(res["fields"]) == 2
for field in res["fields"]:
if field["type"] == filter_field:
@ -79,25 +75,25 @@ class TestInfoBase:
assert field["name"] == vector_field["name"]
assert field["params"] == vector_field["params"]
def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
target: test create normal collection with different fields
method: create collection with diff segment_row_count
method: create collection with diff segment_row_limit
expected: no exception raised
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_count
fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
# assert segment row count
res = connect.get_collection_info(collection_name)
assert res['segment_row_limit'] == get_segment_row_count
assert res['segment_row_limit'] == get_segment_row_limit
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
connect.create_index(collection, field_name, get_simple_index)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
res = connect.get_collection_info(collection)
for field in res["fields"]:
if field["field"] == field_name:
if field["name"] == default_float_vec_field_name:
index = field["indexes"][0]
assert index["index_type"] == get_simple_index["index_type"]
assert index["metric_type"] == get_simple_index["metric_type"]
@ -119,7 +115,7 @@ class TestInfoBase:
assert the value returned by get_collection_info method
expected: False
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
res = connect.get_collection_info(connect, collection_name)
@ -132,7 +128,7 @@ class TestInfoBase:
threads_num = 4
threads = []
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def get_info():
@ -161,18 +157,18 @@ class TestInfoBase:
filter_field = get_filter_field
vector_field = get_vector_field
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count
"segment_row_limit": default_segment_row_limit
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], nb, vector_field["params"]["dim"])
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
res_ids = connect.insert(collection_name, entities)
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_limit'] == segment_row_count
assert res['segment_row_limit'] == default_segment_row_limit
assert len(res["fields"]) == 2
for field in res["fields"]:
if field["type"] == filter_field:
@ -181,22 +177,22 @@ class TestInfoBase:
assert field["name"] == vector_field["name"]
assert field["params"] == vector_field["params"]
def test_create_collection_segment_row_count_after_insert(self, connect, get_segment_row_count):
def test_create_collection_segment_row_limit_after_insert(self, connect, get_segment_row_limit):
target: test create normal collection with different fields
method: create collection with diff segment_row_count
method: create collection with diff segment_row_limit
expected: no exception raised
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_fields)
fields["segment_row_limit"] = get_segment_row_count
fields["segment_row_limit"] = get_segment_row_limit
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], nb, fields["fields"][-1]["params"]["dim"])
entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"])
res_ids = connect.insert(collection_name, entities)
res = connect.get_collection_info(collection_name)
assert res['auto_id'] == True
assert res['segment_row_limit'] == get_segment_row_count
assert res['segment_row_limit'] == get_segment_row_limit
class TestInfoInvalid(object):
@ -6,13 +6,11 @@ import threading
from time import sleep
from multiprocessing import Process
from utils import *
from constants import *
collection_id = "has_collection"
default_fields = gen_default_fields()
uid = "has_collection"
class TestHasCollection:
The following cases are used to test `has_collection` function
@ -55,7 +53,7 @@ class TestHasCollection:
threads_num = 4
threads = []
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def has():
@ -6,15 +6,11 @@ import threading
from time import sleep
from multiprocessing import Process
from utils import *
from constants import *
drop_interval_time = 3
collection_id = "list_collections"
default_fields = gen_default_fields()
uid = "list_collections"
class TestListCollections:
The following cases are used to test `list_collections` function
@ -36,7 +32,7 @@ class TestListCollections:
collection_num = 50
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
assert collection_name in connect.list_collections()
@ -57,7 +53,7 @@ class TestListCollections:
assert the value returned by list_collections method
expected: False
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
assert collection_name not in connect.list_collections()
@ -72,7 +68,7 @@ class TestListCollections:
if result:
for collection_name in result:
result = connect.list_collections()
assert len(result) == 0
@ -85,7 +81,7 @@ class TestListCollections:
threads_num = 4
threads = []
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
def _list():
@ -5,18 +5,11 @@ import itertools
from time import sleep
from multiprocessing import Process
from utils import *
from constants import *
collection_id = "load_collection"
nb = 6000
default_fields = gen_default_fields()
entities = gen_entities(nb)
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
raw_vectors, binary_entities = gen_binary_entities(nb)
uid = "load_collection"
class TestLoadBase:
The following cases are used to test `load_collection` function
@ -49,10 +42,10 @@ class TestLoadBase:
method: insert and create index, load collection with correct params
expected: no error raised
connect.insert(collection, entities)
connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
@ -62,16 +55,16 @@ class TestLoadBase:
method: insert and create index, load binary_collection with correct params
expected: no error raised
connect.insert(binary_collection, binary_entities)
connect.insert(binary_collection, default_binary_entities)
for metric_type in binary_metrics():
get_binary_index["metric_type"] = metric_type
if get_binary_index["index_type"] == "BIN_IVF_FLAT" and metric_type in structure_metrics():
with pytest.raises(Exception) as e:
connect.create_index(binary_collection, binary_field_name, get_binary_index)
connect.create_index(binary_collection, default_binary_vec_field_name, get_binary_index)
connect.create_index(binary_collection, binary_field_name, get_binary_index)
connect.create_index(binary_collection, default_binary_vec_field_name, get_binary_index)
def load_empty_collection(self, connect, collection):
@ -94,7 +87,7 @@ class TestLoadBase:
def test_load_collection_not_existed(self, connect, collection):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
@ -71,6 +71,7 @@ def connect(request):
port = http_port
milvus = get_milvus(host=ip, port=port, handler=handler)
# reset_build_index_threshold(milvus)
except Exception as e:
pytest.exit("Milvus server can not connected, exit pytest ...")
@ -0,0 +1,10 @@
import utils
default_fields = utils.gen_default_fields()
default_binary_fields = utils.gen_binary_default_fields()
default_entity = utils.gen_entities(1)
default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)
default_entities = utils.gen_entities(utils.default_nb)
default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)
@ -7,23 +7,13 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
collection_id = "test_delete"
tag = "1970_01_01"
nb = 6000
field_name = default_float_vec_field_name
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "metric_type":"L2","query": gen_vectors(1, dim), "params": {"nprobe": 10}}}}
{"vector": {field_name: {"topk": 10, "metric_type":"L2", "query": gen_vectors(1, default_dim), "params": {"nprobe": 10}}}}
@ -51,7 +41,7 @@ class TestDeleteBase:
def insert_count(self, request):
@ -63,7 +53,7 @@ class TestDeleteBase:
method: add entity and delete
expected: status DELETED
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
status = connect.delete_entity_by_id(collection, [0])
assert status
@ -93,7 +83,7 @@ class TestDeleteBase:
method: add entity and delete
expected: error raised
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
collection_new = gen_unique_str()
with pytest.raises(Exception) as e:
@ -121,14 +111,14 @@ class TestDeleteBase:
method: add entities and delete one in collection, and one not in collection
expected: no error raised
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], 1]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
res_count = connect.count_entities(collection)
assert res_count == nb - 1
assert res_count == default_nb - 1
def test_insert_delete_B(self, connect, id_collection):
@ -136,8 +126,8 @@ class TestDeleteBase:
method: add entities with the same ids, and delete the id in collection
expected: no error raised, all entities deleted
ids = [1 for i in range(nb)]
res_ids = connect.insert(id_collection, entities, ids)
ids = [1 for i in range(default_nb)]
res_ids = connect.insert(id_collection, default_entities, ids)
delete_ids = [1]
status = connect.delete_entity_by_id(id_collection, delete_ids)
@ -152,7 +142,7 @@ class TestDeleteBase:
method: add one entity and delete two ids
expected: error raised
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
delete_ids = [ids[0], 1]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -166,14 +156,14 @@ class TestDeleteBase:
method: add entities and delete, then flush
expected: entity deleted and no error raised
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
res_count = connect.count_entities(collection)
assert res_count == nb - len(delete_ids)
assert res_count == default_nb - len(delete_ids)
def test_flush_after_delete_binary(self, connect, binary_collection):
@ -181,21 +171,21 @@ class TestDeleteBase:
method: add entities and delete, then flush
expected: entity deleted and no error raised
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(binary_collection, delete_ids)
assert status
res_count = connect.count_entities(binary_collection)
assert res_count == nb - len(delete_ids)
assert res_count == default_nb - len(delete_ids)
def test_insert_delete_binary(self, connect, binary_collection):
method: add entities and delete
expected: status DELETED
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(binary_collection, delete_ids)
@ -206,34 +196,34 @@ class TestDeleteBase:
expected: status DELETED
note: Not flush after delete
insert_ids = [i for i in range(nb)]
ids = connect.insert(id_collection, entities, insert_ids)
insert_ids = [i for i in range(default_nb)]
ids = connect.insert(id_collection, default_entities, insert_ids)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(id_collection, delete_ids)
assert status
new_ids = connect.insert(id_collection, entity, [ids[0]])
new_ids = connect.insert(id_collection, default_entity, [ids[0]])
assert new_ids == [ids[0]]
res_count = connect.count_entities(id_collection)
assert res_count == nb - 1
assert res_count == default_nb - 1
def test_insert_same_ids_after_delete_binary(self, connect, binary_id_collection):
method: add entities, with the same id and delete the ids
expected: status DELETED, all id deleted
insert_ids = [i for i in range(nb)]
ids = connect.insert(binary_id_collection, binary_entities, insert_ids)
insert_ids = [i for i in range(default_nb)]
ids = connect.insert(binary_id_collection, default_binary_entities, insert_ids)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(binary_id_collection, delete_ids)
assert status
new_ids = connect.insert(binary_id_collection, binary_entity, [ids[0]])
new_ids = connect.insert(binary_id_collection, default_binary_entity, [ids[0]])
assert new_ids == [ids[0]]
res_count = connect.count_entities(binary_id_collection)
assert res_count == nb - 1
assert res_count == default_nb - 1
def test_search_after_delete(self, connect, collection):
@ -241,13 +231,14 @@ class TestDeleteBase:
method: add entities and delete, then search
expected: entity deleted and no error raised
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
query = copy.deepcopy(default_single_query)
query["bool"]["must"][0]["vector"][field_name]["query"] = [entity[-1]["values"][0], entities[-1]["values"][0], entities[-1]["values"][-1]]
query["bool"]["must"][0]["vector"][field_name]["query"] =\
[default_entity[-1]["values"][0], default_entities[-1]["values"][0], default_entities[-1]["values"][-1]]
res = connect.search(collection, query)
assert len(res) == len(query["bool"]["must"][0]["vector"][field_name]["query"])
@ -260,7 +251,7 @@ class TestDeleteBase:
method: add entities and delete, then create index
expected: vectors deleted, index created
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -272,7 +263,7 @@ class TestDeleteBase:
method: add entities and delete ids several times
expected: entities deleted
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -288,14 +279,14 @@ class TestDeleteBase:
expected: entities deleted
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
res_count = connect.count_entities(collection)
assert res_count == nb - len(delete_ids)
assert res_count == default_nb - len(delete_ids)
res_get = connect.get_entity_by_id(collection, delete_ids)
assert res_get[0] is None
@ -305,17 +296,17 @@ class TestDeleteBase:
method: create index, insert entities, and delete
expected: entities deleted
ids = [i for i in range(nb)]
ids = [i for i in range(default_nb)]
connect.create_index(id_collection, field_name, get_simple_index)
for i in range(nb):
connect.insert(id_collection, entity, [ids[i]])
for i in range(default_nb):
connect.insert(id_collection, default_entity, [ids[i]])
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(id_collection, delete_ids)
assert status
res_count = connect.count_entities(id_collection)
assert res_count == nb - len(delete_ids)
assert res_count == default_nb - len(delete_ids)
@ -327,30 +318,30 @@ class TestDeleteBase:
method: add entitys with given tag, delete entities with the return ids
expected: entities deleted
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
res_count = connect.count_entities(collection)
assert res_count == nb - 2
assert res_count == default_nb - 2
def test_insert_default_tag_delete(self, connect, collection):
method: add entitys, delete entities with the return ids
expected: entities deleted
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
res_count = connect.count_entities(collection)
assert res_count == nb - 2
assert res_count == default_nb - 2
def test_insert_tags_delete(self, connect, collection):
@ -358,17 +349,17 @@ class TestDeleteBase:
expected: entities deleted
tag_new = "tag_new"
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
ids = connect.insert(collection, entities, partition_tag=tag)
ids_new = connect.insert(collection, entities, partition_tag=tag_new)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
ids_new = connect.insert(collection, default_entities, partition_tag=tag_new)
delete_ids = [ids[0], ids_new[0]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status
res_count = connect.count_entities(collection)
assert res_count == 2 * (nb - 1)
assert res_count == 2 * (default_nb - 1)
def test_insert_tags_index_delete(self, connect, collection, get_simple_index):
@ -376,10 +367,10 @@ class TestDeleteBase:
expected: entities deleted
tag_new = "tag_new"
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
ids = connect.insert(collection, entities, partition_tag=tag)
ids_new = connect.insert(collection, entities, partition_tag=tag_new)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
ids_new = connect.insert(collection, default_entities, partition_tag=tag_new)
connect.create_index(collection, field_name, get_simple_index)
delete_ids = [ids[0], ids_new[0]]
@ -387,7 +378,7 @@ class TestDeleteBase:
assert status
res_count = connect.count_entities(collection)
assert res_count == 2 * (nb - 1)
assert res_count == 2 * (default_nb - 1)
class TestDeleteInvalid(object):
@ -420,6 +411,7 @@ class TestDeleteInvalid(object):
with pytest.raises(Exception) as e:
status = connect.delete_entity_by_id(collection, [1, invalid_id])
def test_delete_entity_with_invalid_collection_name(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception) as e:
@ -2,28 +2,19 @@ import time
import random
import pdb
import copy
import threading
import logging
from multiprocessing import Pool, Process
import concurrent.futures
from threading import current_thread
import pytest
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
collection_id = "test_get"
tag = "1970_01_01"
nb = 6000
entity = gen_entities(1)
binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_single_query = {
"bool": {
"must": [
{"vector": {default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, dim), "params": {"nprobe": 10}}}}
{"vector": {
default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "params": {"nprobe": 10}}}}
@ -34,6 +25,7 @@ class TestGetBase:
The following cases are used to test `get_entity_by_id` function
@ -48,8 +40,6 @@ class TestGetBase:
@ -62,13 +52,13 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
res_count = connect.count_entities(collection)
assert res_count == nb
assert res_count == default_nb
get_ids = [ids[get_pos]]
res = connect.get_entity_by_id(collection, get_ids)
assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][get_pos])
assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][get_pos])
def test_get_entity_multi_ids(self, connect, collection, get_pos):
@ -76,12 +66,12 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
def test_get_entity_parts_ids(self, connect, collection):
@ -89,12 +79,12 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = [ids[0], 1, ids[-1]]
res = connect.get_entity_by_id(collection, get_ids)
assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][0])
assert_equal_vector(res[-1].get(default_float_vec_field_name), entities[-1]["values"][-1])
assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][0])
assert_equal_vector(res[-1].get(default_float_vec_field_name), default_entities[-1]["values"][-1])
assert res[1] is None
def test_get_entity_limit(self, connect, collection, args):
@ -106,7 +96,7 @@ class TestGetBase:
if args["handler"] == "HTTP":
pytest.skip("skip in http mode")
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
with pytest.raises(Exception) as e:
res = connect.get_entity_by_id(collection, ids)
@ -117,13 +107,13 @@ class TestGetBase:
method: add entity, and get one id
expected: entity returned equals insert
ids = [1 for i in range(nb)]
res_ids = connect.insert(id_collection, entities, ids)
ids = [1 for i in range(default_nb)]
res_ids = connect.insert(id_collection, default_entities, ids)
get_ids = [ids[0]]
res = connect.get_entity_by_id(id_collection, get_ids)
assert len(res) == 1
assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][0])
assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][0])
def test_get_entity_params_same_ids(self, connect, id_collection):
@ -132,14 +122,14 @@ class TestGetBase:
expected: entity returned equals insert
ids = [1]
res_ids = connect.insert(id_collection, entity, ids)
res_ids = connect.insert(id_collection, default_entity, ids)
get_ids = [1, 1]
res = connect.get_entity_by_id(id_collection, get_ids)
assert len(res) == len(get_ids)
for i in range(len(get_ids)):
assert_equal_vector(res[i].get(default_float_vec_field_name), entity[-1]["values"][0])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entity[-1]["values"][0])
def test_get_entities_params_same_ids(self, connect, collection):
@ -147,13 +137,13 @@ class TestGetBase:
method: add entities, and get entity with the same ids
expected: entity returned equals insert
res_ids = connect.insert(collection, entities)
res_ids = connect.insert(collection, default_entities)
get_ids = [res_ids[0], res_ids[0]]
res = connect.get_entity_by_id(collection, get_ids)
assert len(res) == len(get_ids)
for i in range(len(get_ids)):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][0])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][0])
@ -167,12 +157,12 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
get_ids = [ids[0], 1, ids[-1]]
res = connect.get_entity_by_id(binary_collection, get_ids)
assert_equal_vector(res[0].get("binary_vector"), binary_entities[-1]["values"][0])
assert_equal_vector(res[-1].get("binary_vector"), binary_entities[-1]["values"][-1])
assert_equal_vector(res[0].get("binary_vector"), default_binary_entities[-1]["values"][0])
assert_equal_vector(res[-1].get("binary_vector"), default_binary_entities[-1]["values"][-1])
assert res[1] is None
@ -180,19 +170,20 @@ class TestGetBase:
The following cases are used to test `get_entity_by_id` function, with tags
def test_get_entities_tag(self, connect, collection, get_pos):
target: test.get_entity_by_id
method: add entities with tag, get
expected: entity returned
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag = default_tag)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
def test_get_entities_tag_default(self, connect, collection, get_pos):
@ -200,13 +191,13 @@ class TestGetBase:
method: add entities with default tag, get
expected: entity returned
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
def test_get_entities_tags_default(self, connect, collection, get_pos):
@ -215,14 +206,14 @@ class TestGetBase:
expected: entity returned
tag_new = "tag_new"
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
def test_get_entities_tags_A(self, connect, collection, get_pos):
@ -231,14 +222,14 @@ class TestGetBase:
expected: entity returned
tag_new = "tag_new"
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
ids = connect.insert(collection, entities, partition_tag=tag)
ids = connect.insert(collection, default_entities, partition_tag = default_tag)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
def test_get_entities_tags_B(self, connect, collection, get_pos):
@ -247,19 +238,19 @@ class TestGetBase:
expected: entity returned
tag_new = "tag_new"
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
new_entities = gen_entities(nb+1)
ids = connect.insert(collection, entities, partition_tag=tag)
ids_new = connect.insert(collection, new_entities, partition_tag=tag_new)
new_entities = gen_entities(default_nb + 1)
ids = connect.insert(collection, default_entities, partition_tag = default_tag)
ids_new = connect.insert(collection, new_entities, partition_tag = tag_new)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
for i in range(get_pos, get_pos*2):
assert_equal_vector(res[i].get(default_float_vec_field_name), new_entities[-1]["values"][i-get_pos])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
for i in range(get_pos, get_pos * 2):
assert_equal_vector(res[i].get(default_float_vec_field_name), new_entities[-1]["values"][i - get_pos])
def test_get_entities_indexed_tag(self, connect, collection, get_simple_index, get_pos):
@ -268,27 +259,28 @@ class TestGetBase:
method: add entities with tag, get
expected: entity returned
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag = default_tag)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
The following cases are used to test `get_entity_by_id` function, with fields params
def test_get_entity_field(self, connect, collection, get_pos):
target: test.get_entity_by_id, get one
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = [ids[get_pos]]
fields = ["int64"]
@ -296,7 +288,7 @@ class TestGetBase:
# assert fields
res = res.dict()
assert res[0]["field"] == fields[0]
assert res[0]["values"] == [entities[0]["values"][get_pos]]
assert res[0]["values"] == [default_entities[0]["values"][get_pos]]
assert res[0]["type"] == DataType.INT64
def test_get_entity_fields(self, connect, collection, get_pos):
@ -305,7 +297,7 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = [ids[get_pos]]
fields = ["int64", "float", default_float_vec_field_name]
@ -315,11 +307,11 @@ class TestGetBase:
assert len(res) == len(fields)
for field in res:
if field["field"] == fields[0]:
assert field["values"] == [entities[0]["values"][get_pos]]
assert field["values"] == [default_entities[0]["values"][get_pos]]
elif field["field"] == fields[1]:
assert field["values"] == [entities[1]["values"][get_pos]]
assert field["values"] == [default_entities[1]["values"][get_pos]]
assert_equal_vector(field["values"][0], entities[-1]["values"][get_pos])
assert_equal_vector(field["values"][0], default_entities[-1]["values"][get_pos])
# TODO: assert exception
def test_get_entity_field_not_match(self, connect, collection, get_pos):
@ -328,7 +320,7 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = [ids[get_pos]]
fields = ["int1288"]
@ -342,7 +334,7 @@ class TestGetBase:
method: add entity, and get
expected: entity returned equals insert
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_ids = [ids[get_pos]]
fields = ["int1288"]
@ -355,9 +347,9 @@ class TestGetBase:
method: add entity and get
expected: empty result
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
res = connect.get_entity_by_id(collection, [1])
res = connect.get_entity_by_id(collection, [1])
assert res[0] is None
def test_get_entity_collection_not_existed(self, connect, collection):
@ -366,7 +358,7 @@ class TestGetBase:
method: add entity and get
expected: error raised
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
collection_new = gen_unique_str()
with pytest.raises(Exception) as e:
@ -377,13 +369,14 @@ class TestGetBase:
The following cases are used to test `get_entity_by_id` function, after deleted
def test_get_entity_after_delete(self, connect, collection, get_pos):
target: test.get_entity_by_id
method: add entities, and delete, get entity by the given id
expected: empty result
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[get_pos]]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -398,7 +391,7 @@ class TestGetBase:
method: add entities, and delete, get entity by the given id
expected: empty result
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = ids[:get_pos]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -414,7 +407,7 @@ class TestGetBase:
method: add entities, and delete, get entity by the given id
expected: empty result
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = ids[:get_pos]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -431,13 +424,13 @@ class TestGetBase:
method: add entities batch, create index, get
expected: entity returned
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
def test_get_entities_indexed_single(self, connect, collection, get_simple_index, get_pos):
@ -447,14 +440,31 @@ class TestGetBase:
expected: entity returned
ids = []
for i in range(nb):
ids.append(connect.insert(collection, entity)[0])
for i in range(default_nb):
ids.append(connect.insert(collection, default_entity)[0])
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entity[-1]["values"][0])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entity[-1]["values"][0])
def test_get_entities_with_deleted_ids(self, connect, id_collection):
target: test.get_entity_by_id
method: add entities ids, and delete part, get entity include the deleted id
ids = [i for i in range(default_nb)]
res_ids = connect.insert(id_collection, default_entities, ids)
status = connect.delete_entity_by_id(id_collection, [res_ids[1]])
get_ids = res_ids[:2]
res = connect.get_entity_by_id(id_collection, get_ids)
assert len(res) == len(get_ids)
assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][0])
assert res[1] is None
# TODO: unable to set config
def _test_get_entities_after_delete_disable_autoflush(self, connect, collection, get_pos):
@ -463,7 +473,7 @@ class TestGetBase:
method: disable autoflush, add entities, and delete, get entity by the given id
expected: empty result
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = ids[:get_pos]
@ -472,7 +482,7 @@ class TestGetBase:
get_ids = ids[:get_pos]
res = connect.get_entity_by_id(collection, get_ids)
for i in range(get_pos):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
@ -482,9 +492,9 @@ class TestGetBase:
method: add entities with the same ids, and delete, get entity by the given id
expected: empty result
ids = [i for i in range(nb)]
ids = [i for i in range(default_nb)]
ids[0] = 1
res_ids = connect.insert(id_collection, entities, ids)
res_ids = connect.insert(id_collection, default_entities, ids)
status = connect.delete_entity_by_id(id_collection, [1])
@ -498,8 +508,8 @@ class TestGetBase:
method: add entities into partition, and delete, get entity by the given id
expected: get one entity
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag = default_tag)
status = connect.delete_entity_by_id(collection, [ids[get_pos]])
@ -507,14 +517,16 @@ class TestGetBase:
assert res[0] is None
def test_get_entity_by_id_multithreads(self, connect, collection):
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_id = ids[100:200]
def get():
res = connect.get_entity_by_id(collection, get_id)
assert len(res) == len(get_id)
for i in range(len(res)):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][100+i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][100 + i])
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
future_results = {executor.submit(
get): i for i in range(10)}
@ -528,14 +540,14 @@ class TestGetBase:
method: thread do insert and get
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_id = ids[:1000]
def insert():
# logging.getLogger().info(current_thread().getName() + " insert")
step = 1000
for i in range(nb // step):
for i in range(default_nb // step):
group_entities = gen_entities(step, False)
connect.insert(collection, group_entities)
@ -545,7 +557,7 @@ class TestGetBase:
res = connect.get_entity_by_id(collection, get_id)
assert len(res) == len(get_id)
for i in range(len(res)):
assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
for i in range(20):
@ -554,7 +566,7 @@ class TestGetBase:
def test_get_entity_by_id_insert_multi_threads(self, connect, collection):
def test_get_entity_by_id_insert_multi_threads_2(self, connect, collection):
target: test.get_entity_by_id
method: thread do insert and get
@ -572,16 +584,16 @@ class TestGetBase:
# logging.getLogger().info(current_thread().getName() + " insert")
for group_vector in group_vectors:
group_entities = [
{"field": "int64", "type": DataType.INT64, "values": [i for i in range(step)]},
{"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(step)]},
{"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": group_vector}
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(step)]},
{"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(step)]},
{"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": group_vector}
group_ids = connect.insert(collection, group_entities)
executor.submit(get, group_ids, group_entities)
step = 100
vectors = gen_vectors(nb, dimension, False)
vectors = gen_vectors(default_nb, default_dim, False)
group_vectors = [vectors[i:i + step] for i in range(0, len(vectors), step)]
task = executor.submit(insert, group_vectors)
@ -591,6 +603,7 @@ class TestGetInvalid(object):
Test get entities with invalid params
@ -620,7 +633,7 @@ class TestGetInvalid(object):
expected: raise an exception
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
ids = [entity_id for _ in range(default_nb)]
with pytest.raises(Exception):
connect.get_entity_by_id(collection, ids)
@ -632,7 +645,7 @@ class TestGetInvalid(object):
expected: raise an exception
entity_id = get_entity_id
ids = [i for i in range(nb)]
ids = [i for i in range(default_nb)]
ids[-1] = entity_id
with pytest.raises(Exception):
connect.get_entity_by_id(collection, ids)
@ -650,4 +663,4 @@ class TestGetInvalid(object):
ids = [1]
fields = [field_name]
with pytest.raises(Exception):
res = connect.get_entity_by_id(collection, ids, fields=fields)
res = connect.get_entity_by_id(collection, ids, fields = fields)
@ -1,30 +1,22 @@
import logging
import time
import pdb
import copy
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from milvus import DataType
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
collection_id = "test_insert"
tag = "1970_01_01"
insert_interval_time = 1.5
nb = 6000
field_name = default_float_vec_field_name
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
uid = "test_insert"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim),"metric_type":"L2",
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2",
"params": {"nprobe": 10}}}}
@ -89,9 +81,9 @@ class TestInsertBase:
method: insert entity into a random named collection
expected: error raised
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
connect.insert(collection_name, entities)
connect.insert(collection_name, default_entities)
def test_insert_drop_collection(self, connect, collection):
@ -100,7 +92,7 @@ class TestInsertBase:
method: insert vector and delete collection
expected: no error raised
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
assert len(ids) == 1
@ -111,7 +103,7 @@ class TestInsertBase:
method: insert vector, sleep, and delete collection
expected: no error raised
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
assert len(ids) == 1
@ -123,10 +115,15 @@ class TestInsertBase:
method: insert vector and build index
expected: no error raised
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
connect.create_index(collection, field_name, get_simple_index)
info = connect.get_collection_info(collection)
fields = info["fields"]
for field in fields:
if field["name"] == field_name:
assert field["indexes"][0] == get_simple_index
def test_insert_after_create_index(self, connect, collection, get_simple_index):
@ -136,8 +133,13 @@ class TestInsertBase:
expected: no error raised
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
info = connect.get_collection_info(collection)
fields = info["fields"]
for field in fields:
if field["name"] == field_name:
assert field["indexes"][0] == get_simple_index
def test_insert_search(self, connect, collection):
@ -146,17 +148,27 @@ class TestInsertBase:
method: insert vector, sleep, and search collection
expected: no error raised
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
res = connect.search(collection, default_single_query)
assert res
def test_insert_segment_row_count(self, connect, collection):
nb = default_segment_row_limit + 1
res_ids = connect.insert(collection, gen_entities(nb))
assert len(res_ids) == nb
stats = connect.get_collection_stats(collection)
assert len(stats['partitions'][0]['segments']) == 2
for segment in stats['partitions'][0]['segments']:
assert segment['row_count'] in [default_segment_row_limit, 1]
def insert_count(self, request):
@ -207,7 +219,7 @@ class TestInsertBase:
collection_name = gen_unique_str("test_collection")
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count,
"segment_row_limit": default_segment_row_limit,
"auto_id": True
connect.create_collection(collection_name, fields)
@ -228,10 +240,10 @@ class TestInsertBase:
method: test insert vectors twice, use customize ids first, and then use no ids
expected: error raised
ids = [i for i in range(nb)]
res_ids = connect.insert(id_collection, entities, ids)
ids = [i for i in range(default_nb)]
res_ids = connect.insert(id_collection, default_entities, ids)
with pytest.raises(Exception) as e:
res_ids_new = connect.insert(id_collection, entities)
res_ids_new = connect.insert(id_collection, default_entities)
# TODO: assert exception && enable
@ -243,7 +255,7 @@ class TestInsertBase:
expected: error raised
with pytest.raises(Exception) as e:
res_ids = connect.insert(id_collection, entities)
res_ids = connect.insert(id_collection, default_entities)
def test_insert_ids_length_not_match_batch(self, connect, id_collection):
@ -252,10 +264,10 @@ class TestInsertBase:
method: create collection and insert vectors in it
expected: raise an exception
ids = [i for i in range(1, nb)]
ids = [i for i in range(1, default_nb)]
with pytest.raises(Exception) as e:
res_ids = connect.insert(id_collection, entities, ids)
res_ids = connect.insert(id_collection, default_entities, ids)
def test_insert_ids_length_not_match_single(self, connect, collection):
@ -264,10 +276,10 @@ class TestInsertBase:
method: create collection and insert vectors in it
expected: raise an exception
ids = [i for i in range(1, nb)]
ids = [i for i in range(1, default_nb)]
with pytest.raises(Exception) as e:
res_ids = connect.insert(collection, entity, ids)
res_ids = connect.insert(collection, default_entity, ids)
def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):
@ -282,10 +294,10 @@ class TestInsertBase:
collection_name = gen_unique_str("test_collection")
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count
"segment_row_limit": default_segment_row_limit
connect.create_collection(collection_name, fields)
entities = gen_entities_by_fields(fields["fields"], nb, dim)
entities = gen_entities_by_fields(fields["fields"], nb, default_dim)
res_ids = connect.insert(collection_name, entities)
res_count = connect.count_entities(collection_name)
@ -298,9 +310,10 @@ class TestInsertBase:
method: create collection and insert entities in it, with the partition_tag param
expected: the collection row count equals to nq
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
assert len(ids) == nb
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(collection, default_tag)
def test_insert_tag_with_ids(self, connect, id_collection):
@ -309,9 +322,9 @@ class TestInsertBase:
method: create collection and insert entities in it, with the partition_tag param
expected: the collection row count equals to nq
connect.create_partition(id_collection, tag)
ids = [i for i in range(nb)]
res_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
res_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
assert res_ids == ids
@ -321,12 +334,12 @@ class TestInsertBase:
method: create partition and insert info collection without tag params
expected: the collection row count equals to nb
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities)
assert len(ids) == nb
assert len(ids) == default_nb
res_count = connect.count_entities(collection)
assert res_count == nb
assert res_count == default_nb
def test_insert_tag_not_existed(self, connect, collection):
@ -337,7 +350,7 @@ class TestInsertBase:
tag = gen_unique_str()
with pytest.raises(Exception) as e:
ids = connect.insert(collection, entities, partition_tag=tag)
ids = connect.insert(collection, default_entities, partition_tag=tag)
def test_insert_tag_existed(self, connect, collection):
@ -346,12 +359,12 @@ class TestInsertBase:
method: create collection and insert entities in it repeatly, with the partition_tag param
expected: the collection row count equals to nq
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
res_count = connect.count_entities(collection)
assert res_count == 2 * nb
assert res_count == 2 * default_nb
def test_insert_without_connect(self, dis_connect, collection):
@ -361,7 +374,7 @@ class TestInsertBase:
expected: raise exception
with pytest.raises(Exception) as e:
ids = dis_connect.insert(collection, entities)
ids = dis_connect.insert(collection, default_entities)
def test_insert_collection_not_existed(self, connect):
@ -370,7 +383,7 @@ class TestInsertBase:
expected: error raised
with pytest.raises(Exception) as e:
ids = connect.insert(gen_unique_str("not_exist_collection"), entities)
ids = connect.insert(gen_unique_str("not_exist_collection"), default_entities)
def test_insert_dim_not_matched(self, connect, collection):
@ -378,8 +391,8 @@ class TestInsertBase:
method: the entities dimension is half of the collection dimension, check the status
expected: error raised
vectors = gen_vectors(nb, int(dim) // 2)
insert_entities = copy.deepcopy(entities)
vectors = gen_vectors(default_nb, int(default_dim) // 2)
insert_entities = copy.deepcopy(default_entities)
insert_entities[-1]["values"] = vectors
with pytest.raises(Exception) as e:
ids = connect.insert(collection, insert_entities)
@ -390,7 +403,7 @@ class TestInsertBase:
method: update entity field name
expected: error raised
tmp_entity = update_field_name(copy.deepcopy(entity), "int64", "int64new")
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", "int64new")
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -401,7 +414,7 @@ class TestInsertBase:
method: update entity field type
expected: error raised
tmp_entity = update_field_type(copy.deepcopy(entity), "int64", DataType.FLOAT)
tmp_entity = update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -412,7 +425,7 @@ class TestInsertBase:
method: update entity field value
expected: error raised
tmp_entity = update_field_value(copy.deepcopy(entity), DataType.FLOAT, 's')
tmp_entity = update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's')
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -422,7 +435,7 @@ class TestInsertBase:
method: add entity field
expected: error raised
tmp_entity = add_field(copy.deepcopy(entity))
tmp_entity = add_field(copy.deepcopy(default_entity))
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -432,7 +445,7 @@ class TestInsertBase:
method: add entity vector field
expected: error raised
tmp_entity = add_vector_field(nb, dim)
tmp_entity = add_vector_field(default_nb, default_dim)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -442,7 +455,7 @@ class TestInsertBase:
method: remove entity field
expected: error raised
tmp_entity = remove_field(copy.deepcopy(entity))
tmp_entity = remove_field(copy.deepcopy(default_entity))
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -452,7 +465,7 @@ class TestInsertBase:
method: remove entity vector field
expected: error raised
tmp_entity = remove_vector_field(copy.deepcopy(entity))
tmp_entity = remove_vector_field(copy.deepcopy(default_entity))
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -462,7 +475,7 @@ class TestInsertBase:
method: remove entity vector field
expected: error raised
tmp_entity = copy.deepcopy(entity)
tmp_entity = copy.deepcopy(default_entity)
del tmp_entity[-1]["values"]
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -473,7 +486,7 @@ class TestInsertBase:
method: remove entity vector field
expected: error raised
tmp_entity = copy.deepcopy(entity)
tmp_entity = copy.deepcopy(default_entity)
del tmp_entity[-1]["type"]
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -484,8 +497,8 @@ class TestInsertBase:
method: remove entity vector field
expected: error raised
tmp_entity = copy.deepcopy(entity)
del tmp_entity[-1]["field"]
tmp_entity = copy.deepcopy(default_entity)
del tmp_entity[-1]["name"]
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
@ -506,7 +519,7 @@ class TestInsertBase:
def insert(thread_i):
logging.getLogger().info("In thread-%d" % thread_i)
res_ids = milvus.insert(collection, entities)
res_ids = milvus.insert(collection, default_entities)
for i in range(thread_num):
@ -516,7 +529,7 @@ class TestInsertBase:
for th in threads:
res_count = milvus.count_entities(collection)
assert res_count == thread_num * nb
assert res_count == thread_num * default_nb
# TODO: unable to set config
@ -526,10 +539,102 @@ class TestInsertBase:
method: disable autoflush and insert, get entity
expected: the count is equal to 0
delete_nums = 500
ids = connect.insert(collection, entities)
res = connect.get_entity_by_id(collection, ids[:500])
assert len(res) == 0
ids = connect.insert(collection, default_entities)
res = connect.get_entity_by_id(collection, ids[:delete_nums])
assert len(res) == delete_nums
assert res[0] is None
class TestInsertBinary:
def get_binary_index(self, request):
request.param["metric_type"] = "JACCARD"
return request.param
def test_insert_binary_entities(self, connect, binary_collection):
target: test insert entities in binary collection
method: create collection and insert binary entities in it
expected: the collection row count equals to nb
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
assert connect.count_entities(binary_collection) == default_nb
def test_insert_binary_tag(self, connect, binary_collection):
target: test insert entities and create partition tag
method: create collection and insert binary entities in it, with the partition_tag param
expected: the collection row count equals to nb
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(binary_collection, default_tag)
def test_insert_binary_multi_times(self, connect, binary_collection):
target: test insert entities multi times and final flush
method: create collection and insert binary entity multi and final flush
expected: the collection row count equals to nb
for i in range(default_nb):
ids = connect.insert(binary_collection, default_binary_entity)
assert len(ids) == 1
assert connect.count_entities(binary_collection) == default_nb
def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
target: test insert binary entities after build index
method: build index and insert entities
expected: no error raised
connect.create_index(binary_collection, binary_field_name, get_binary_index)
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["name"] == binary_field_name:
assert field["indexes"][0] == get_binary_index
def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index):
target: test build index insert after vector
method: insert vector and build index
expected: no error raised
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.create_index(binary_collection, binary_field_name, get_binary_index)
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["name"] == binary_field_name:
assert field["indexes"][0] == get_binary_index
def test_insert_binary_search(self, connect, binary_collection):
target: test search vector after insert vector after a while
method: insert vector, sleep, and search collection
expected: no error raised
ids = connect.insert(binary_collection, default_binary_entities)
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD")
res = connect.search(binary_collection, query)
assert res
class TestInsertAsync:
@ -628,7 +733,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
collection_new = gen_unique_str()
future = connect.insert(collection_new, entities, _async=True)
future = connect.insert(collection_new, default_entities, _async=True)
with pytest.raises(Exception) as e:
result = future.result()
@ -671,14 +776,14 @@ class TestInsertMultiCollections:
collection_num = 10
collection_list = []
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection_name, entities)
ids = connect.insert(collection_name, default_entities)
assert len(ids) == nb
assert len(ids) == default_nb
count = connect.count_entities(collection_name)
assert count == nb
assert count == default_nb
def test_drop_collection_insert_vector_another(self, connect, collection):
@ -687,10 +792,10 @@ class TestInsertMultiCollections:
method: delete collection_2 and insert vector to collection_1
expected: row count equals the length of entities inserted
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection_name, entity)
ids = connect.insert(collection_name, default_entity)
assert len(ids) == 1
@ -701,10 +806,10 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
@ -714,9 +819,9 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
connect.create_index(collection, field_name, get_simple_index)
count = connect.count_entities(collection_name)
assert count == 0
@ -728,9 +833,9 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
connect.create_index(collection, field_name, get_simple_index)
count = connect.count_entities(collection)
@ -743,11 +848,11 @@ class TestInsertMultiCollections:
method: search collection and insert vector
expected: status ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
res = connect.search(collection, default_single_query)
ids = connect.insert(collection_name, entity)
ids = connect.insert(collection_name, default_entity)
count = connect.count_entities(collection_name)
assert count == 1
@ -759,9 +864,9 @@ class TestInsertMultiCollections:
method: search collection and insert vector
expected: status ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
result = connect.search(collection_name, default_single_query)
@ -771,9 +876,9 @@ class TestInsertMultiCollections:
method: search collection , sleep, and insert vector
expected: status ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
result = connect.search(collection_name, default_single_query)
@ -839,44 +944,44 @@ class TestInsertInvalid(object):
expected: raise an exception
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
ids = [entity_id for _ in range(default_nb)]
with pytest.raises(Exception):
connect.insert(id_collection, entities, ids)
connect.insert(id_collection, default_entities, ids)
def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception):
connect.insert(collection_name, entity)
connect.insert(collection_name, default_entity)
def test_insert_with_invalid_tag_name(self, connect, collection, get_tag_name):
tag_name = get_tag_name
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
if tag_name is not None:
with pytest.raises(Exception):
connect.insert(collection, entity, partition_tag=tag_name)
connect.insert(collection, default_entity, partition_tag=tag_name)
connect.insert(collection, entity, partition_tag=tag_name)
connect.insert(collection, default_entity, partition_tag=tag_name)
def test_insert_with_invalid_field_name(self, connect, collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(entity), "int64", get_field_name)
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
def test_insert_with_invalid_field_type(self, connect, collection, get_field_type):
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(entity), 'float', field_type)
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
def test_insert_with_invalid_field_value(self, connect, collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'int64', field_value)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)
def test_insert_with_invalid_field_vector_value(self, connect, collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(entity)
tmp_entity = copy.deepcopy(default_entity)
src_vector = tmp_entity[-1]["values"]
src_vector[0][1] = get_field_vectors_value
with pytest.raises(Exception):
@ -939,21 +1044,19 @@ class TestInsertInvalidBinary(object):
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(binary_entity), "int64", get_field_name)
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(binary_entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(binary_entity)
tmp_entity = copy.deepcopy(default_binary_entity)
src_vector = tmp_entity[-1]["values"]
src_vector[0][1] = get_field_vectors_value
with pytest.raises(Exception):
@ -967,34 +1070,20 @@ class TestInsertInvalidBinary(object):
expected: raise an exception
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
ids = [entity_id for _ in range(default_nb)]
with pytest.raises(Exception):
connect.insert(binary_id_collection, binary_entities, ids)
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
connect.insert(binary_id_collection, default_binary_entities, ids)
def test_insert_with_invalid_field_type(self, connect, binary_collection, get_field_type):
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(binary_entity), 'int64', field_type)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(binary_entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(binary_entities)
tmp_entity = copy.deepcopy(default_binary_entities)
src_vector = tmp_entity[-1]["values"]
src_vector[1] = get_field_vectors_value
with pytest.raises(Exception):
@ -6,20 +6,9 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
dim = 128
segment_row_count = 100000
nb = 6000
tag = "1970_01_01"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
collection_id = "list_id_in_segment"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
uid = "list_id_in_segment"
def get_segment_id(connect, collection, nb=1, vec_type='float', index_params=None):
if vec_type != "float":
@ -30,9 +19,9 @@ def get_segment_id(connect, collection, nb=1, vec_type='float', index_params=Non
if index_params:
if vec_type == 'float':
connect.create_index(collection, field_name, index_params)
connect.create_index(collection, default_float_vec_field_name, index_params)
connect.create_index(collection, binary_field_name, index_params)
connect.create_index(collection, default_binary_vec_field_name, index_params)
stats = connect.get_collection_stats(collection)
return ids, stats["partitions"][0]["segments"][0]["id"]
@ -61,7 +50,7 @@ class TestListIdInSegmentBase:
method: call list_id_in_segment with a random collection_name, which is not in db
expected: status not ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
ids, segment_id = get_segment_id(connect, collection)
with pytest.raises(Exception) as e:
vector_ids = connect.list_id_in_segment(collection_name, segment_id)
@ -73,6 +62,7 @@ class TestListIdInSegmentBase:
def get_collection_name(self, request):
yield request.param
def test_list_id_in_segment_collection_name_invalid(self, connect, collection, get_collection_name):
target: get vector ids where collection name is invalid
@ -102,7 +92,7 @@ class TestListIdInSegmentBase:
expected: status not ok
ids, seg_id = get_segment_id(connect, collection)
# segment = gen_unique_str(collection_id)
# segment = gen_unique_str(uid)
with pytest.raises(Exception) as e:
vector_ids = connect.list_id_in_segment(collection, seg_id + 10000)
@ -129,11 +119,11 @@ class TestListIdInSegmentBase:
nb = 10
entities = gen_entities(nb)
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, entities, partition_tag=default_tag)
stats = connect.get_collection_stats(collection)
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["tag"] == default_tag
vector_ids = connect.list_id_in_segment(collection, stats["partitions"][1]["segments"][0]["id"])
# vector_ids should match ids
assert len(vector_ids) == nb
@ -157,7 +147,7 @@ class TestListIdInSegmentBase:
method: call list_id_in_segment and check if the segment contains vectors
expected: status ok
ids, seg_id = get_segment_id(connect, collection, nb=nb, index_params=get_simple_index)
ids, seg_id = get_segment_id(connect, collection, nb=default_nb, index_params=get_simple_index)
connect.list_id_in_segment(collection, seg_id)
except Exception as e:
@ -171,11 +161,11 @@ class TestListIdInSegmentBase:
method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
expected: status ok
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
stats = connect.get_collection_stats(collection)
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["tag"] == default_tag
connect.list_id_in_segment(collection, stats["partitions"][1]["segments"][0]["id"])
except Exception as e:
@ -183,7 +173,6 @@ class TestListIdInSegmentBase:
# vector_ids should match ids
def test_list_id_in_segment_after_delete_vectors(self, connect, collection):
target: get vector ids after vectors are deleted
@ -200,6 +189,24 @@ class TestListIdInSegmentBase:
assert len(vector_ids) == 1
assert vector_ids[0] == ids[1]
def test_list_id_in_segment_after_delete_vectors(self, connect, collection):
target: get vector ids after vectors are deleted
method: add vectors and delete a few, call list_id_in_segment
expected: vector_ids decreased after vectors deleted
nb = 60
delete_length = 10
ids, seg_id = get_segment_id(connect, collection, nb=nb)
delete_ids = ids[:delete_length]
status = connect.delete_entity_by_id(collection, delete_ids)
stats = connect.get_collection_stats(collection)
vector_ids = connect.list_id_in_segment(collection, stats["partitions"][0]["segments"][0]["id"])
assert len(vector_ids) == nb - delete_length
assert vector_ids[0] == ids[delete_length]
def test_list_id_in_segment_with_index_ip(self, connect, collection, get_simple_index):
@ -208,11 +215,11 @@ class TestListIdInSegmentBase:
expected: ids returned in ids inserted
get_simple_index["metric_type"] = "IP"
ids, seg_id = get_segment_id(connect, collection, nb=nb, index_params=get_simple_index)
ids, seg_id = get_segment_id(connect, collection, nb=default_nb, index_params=get_simple_index)
vector_ids = connect.list_id_in_segment(collection, seg_id)
segment_row_count = connect.get_collection_info(collection)["segment_row_limit"]
assert vector_ids[0:segment_row_count] == ids[0:segment_row_count]
segment_row_limit = connect.get_collection_info(collection)["segment_row_limit"]
assert vector_ids[0:segment_row_limit] == ids[0:segment_row_limit]
class TestListIdInSegmentBinary:
@ -245,10 +252,10 @@ class TestListIdInSegmentBinary:
method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
expected: status ok
connect.create_partition(binary_collection, tag)
connect.create_partition(binary_collection, default_tag)
nb = 10
vectors, entities = gen_binary_entities(nb)
ids = connect.insert(binary_collection, entities, partition_tag=tag)
ids = connect.insert(binary_collection, entities, partition_tag=default_tag)
stats = connect.get_collection_stats(binary_collection)
vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][1]["segments"][0]["id"])
@ -275,7 +282,7 @@ class TestListIdInSegmentBinary:
method: call list_id_in_segment and check if the segment contains vectors
expected: status ok
ids, seg_id = get_segment_id(connect, binary_collection, nb=nb, index_params=get_jaccard_index, vec_type='binary')
ids, seg_id = get_segment_id(connect, binary_collection, nb=default_nb, index_params=get_jaccard_index, vec_type='binary')
vector_ids = connect.list_id_in_segment(binary_collection, seg_id)
@ -285,11 +292,11 @@ class TestListIdInSegmentBinary:
method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
expected: status ok
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
stats = connect.get_collection_stats(binary_collection)
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["tag"] == default_tag
vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][1]["segments"][0]["id"])
# vector_ids should match ids
@ -9,36 +9,27 @@ import numpy as np
from milvus import DataType
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
top_k_limit = 16384
collection_id = "search"
tag = "1970_01_01"
insert_interval_time = 1.5
nb = 6000
top_k = 10
uid = "test_search"
nq = 1
nprobe = 1
epsilon = 0.001
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
default_fields = gen_default_fields()
search_param = {"nprobe": 1}
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq)
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
entities = gen_entities(default_nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(default_nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq)
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k, nq)
def init_data(connect, collection, nb=6000, partition_tags=None, auto_id=True):
def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True):
Generate entities and add it in collection
global entities
if nb == 6000:
if nb == 1200:
insert_entities = entities
insert_entities = gen_entities(nb, is_normal=True)
@ -56,14 +47,14 @@ def init_data(connect, collection, nb=6000, partition_tags=None, auto_id=True):
return insert_entities, ids
def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=None):
Generate entities and add it in collection
ids = []
global binary_entities
global raw_vectors
if nb == 6000:
if nb == 1200:
insert_entities = binary_entities
insert_raw_vectors = raw_vectors
@ -141,7 +132,7 @@ class TestSearchBase:
params=[1, 10, 16385]
params=[1, 10]
def get_top_k(self, request):
yield request.param
@ -155,7 +146,7 @@ class TestSearchBase:
def test_search_flat(self, connect, collection, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, change top-k value
target: test basic search function, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
@ -163,7 +154,26 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k <= top_k_limit:
if top_k <= max_top_k:
res = connect.search(collection, query)
assert len(res[0]) == top_k
assert res[0]._distances[0] <= epsilon
assert check_id_result(res[0], ids[0])
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
def test_search_flat_top_k(self, connect, collection, get_nq):
target: test basic search function, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
top_k = 16385
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k <= max_top_k:
res = connect.search(collection, query)
assert len(res[0]) == top_k
assert res[0]._distances[0] <= epsilon
@ -174,7 +184,7 @@ class TestSearchBase:
def test_search_field(self, connect, collection, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, change top-k value
target: test basic search function, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
@ -182,7 +192,7 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k <= top_k_limit:
if top_k <= max_top_k:
res = connect.search(collection, query, fields=["float_vector"])
assert len(res[0]) == top_k
assert res[0]._distances[0] <= epsilon
@ -198,7 +208,7 @@ class TestSearchBase:
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: the length of the result is top_k
@ -206,13 +216,13 @@ class TestSearchBase:
nq = get_nq
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -233,16 +243,16 @@ class TestSearchBase:
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type,
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, metric_type=search_metric_type,
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: add vectors into collection, search with the given vectors, check the result
expected: the length of the result is top_k, search collection with partition tag return empty
@ -250,14 +260,14 @@ class TestSearchBase:
nq = get_nq
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -266,13 +276,13 @@ class TestSearchBase:
assert len(res[0]) >= top_k
assert res[0]._distances[0] < epsilon
assert check_id_result(res[0], ids[0])
res = connect.search(collection, query, partition_tags=[tag])
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res) == nq
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: the length of the result is top_k
@ -280,15 +290,15 @@ class TestSearchBase:
nq = get_nq
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
connect.create_partition(collection, default_tag)
entities, ids = init_data(connect, collection, partition_tags=default_tag)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
for tags in [[tag], [tag, "new_tag"]]:
if top_k > top_k_limit:
for tags in [[default_tag], [default_tag, "new_tag"]]:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query, partition_tags=tags)
@ -301,7 +311,7 @@ class TestSearchBase:
def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search with the given vectors and tag (tag name not existed in collection), check the result
expected: error raised
@ -309,7 +319,7 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query, partition_tags=["new_tag"])
@ -320,7 +330,7 @@ class TestSearchBase:
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: the length of the result is top_k
@ -328,16 +338,16 @@ class TestSearchBase:
nq = 2
new_tag = "new_tag"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
entities, ids = init_data(connect, collection, partition_tags=default_tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -353,7 +363,7 @@ class TestSearchBase:
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: the length of the result is top_k
@ -362,7 +372,7 @@ class TestSearchBase:
tag = "tag"
new_tag = "new_tag"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, new_tag)
@ -371,7 +381,7 @@ class TestSearchBase:
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -389,7 +399,7 @@ class TestSearchBase:
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, change top-k value
target: test basic search function, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
@ -397,7 +407,7 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
if top_k <= top_k_limit:
if top_k <= max_top_k:
res = connect.search(collection, query)
assert len(res[0]) == top_k
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
@ -409,7 +419,7 @@ class TestSearchBase:
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: the length of the result is top_k
@ -417,14 +427,14 @@ class TestSearchBase:
nq = get_nq
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
entities, ids = init_data(connect, collection)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -437,7 +447,7 @@ class TestSearchBase:
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: add vectors into collection, search with the given vectors, check the result
expected: the length of the result is top_k, search collection with partition tag return empty
@ -445,16 +455,16 @@ class TestSearchBase:
nq = get_nq
metric_type = "IP"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
entities, ids = init_data(connect, collection)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -463,13 +473,13 @@ class TestSearchBase:
assert len(res[0]) >= top_k
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
assert check_id_result(res[0], ids[0])
res = connect.search(collection, query, partition_tags=[tag])
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res) == nq
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params is corrent, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: the length of the result is top_k
@ -478,17 +488,17 @@ class TestSearchBase:
metric_type = "IP"
new_tag = "new_tag"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
entities, ids = init_data(connect, collection, partition_tags=default_tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -518,7 +528,7 @@ class TestSearchBase:
method: search with the random collection_name, which is not in db
expected: status not ok
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
res = connect.search(collection_name, default_query)
@ -531,8 +541,8 @@ class TestSearchBase:
nq = 2
search_param = {"nprobe": 1}
entities, ids = init_data(connect, collection, nb=nq)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param)
distance_0 = l2(vecs[0], inside_vecs[0])
distance_1 = l2(vecs[0], inside_vecs[1])
res = connect.search(collection, query)
@ -549,11 +559,11 @@ class TestSearchBase:
entities, ids = init_data(connect, id_collection, auto_id=False)
connect.create_index(id_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
inside_vecs = entities[-1]["values"]
min_distance = 1.0
min_id = None
for i in range(nb):
for i in range(default_nb):
tmp_dis = l2(vecs[0], inside_vecs[i])
if min_distance > tmp_dis:
min_distance = tmp_dis
@ -577,9 +587,9 @@ class TestSearchBase:
metirc_type = "IP"
search_param = {"nprobe": 1}
entities, ids = init_data(connect, collection, nb=nq)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metirc_type,
inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param)
distance_0 = ip(vecs[0], inside_vecs[0])
distance_1 = ip(vecs[0], inside_vecs[1])
res = connect.search(collection, query)
@ -598,12 +608,12 @@ class TestSearchBase:
get_simple_index["metric_type"] = metirc_type
connect.create_index(id_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metirc_type,
inside_vecs = entities[-1]["values"]
max_distance = 0
max_id = None
for i in range(nb):
for i in range(default_nb):
tmp_dis = ip(vecs[0], inside_vecs[i])
if max_distance < tmp_dis:
max_distance = tmp_dis
@ -627,7 +637,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="JACCARD")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD")
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
@ -643,7 +653,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="L2")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2")
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)
@ -659,7 +669,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = hamming(query_int_vectors[0], int_vectors[0])
distance_1 = hamming(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="HAMMING")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING")
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
@ -675,7 +685,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = substructure(query_int_vectors[0], int_vectors[0])
distance_1 = substructure(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="SUBSTRUCTURE")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUBSTRUCTURE")
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
@ -708,7 +718,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="SUPERSTRUCTURE")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUPERSTRUCTURE")
res = connect.search(binary_collection, query)
assert len(res[0]) == 0
@ -743,7 +753,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="TANIMOTO")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO")
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon
@ -759,7 +769,7 @@ class TestSearchBase:
top_k = 10
threads_num = 4
threads = []
collection = gen_unique_str(collection_id)
collection = gen_unique_str(uid)
uri = "tcp://%s:%s" % (args["ip"], args["port"])
# create collection
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
@ -793,7 +803,7 @@ class TestSearchBase:
top_k = 10
threads_num = 4
threads = []
collection = gen_unique_str(collection_id)
collection = gen_unique_str(uid)
uri = "tcp://%s:%s" % (args["ip"], args["port"])
# create collection
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
@ -825,10 +835,10 @@ class TestSearchBase:
top_k = 10
nq = 20
for i in range(num):
collection = gen_unique_str(collection_id + str(i))
collection = gen_unique_str(uid + str(i))
connect.create_collection(collection, default_fields)
entities, ids = init_data(connect, collection)
assert len(ids) == nb
assert len(ids) == default_nb
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
res = connect.search(collection, query)
assert len(res) == nq
@ -885,7 +895,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
res = connect.search(collection, default_query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
def test_query_wrong_format(self, connect, collection):
@ -971,7 +981,7 @@ class TestSearchDSL(object):
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
def test_query_term_values_parts_in(self, connect, collection):
@ -981,11 +991,11 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
expr = {"must": [gen_default_vector_expr(default_query),
gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])]}
gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
@ -997,7 +1007,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
expr = {
"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, nb)])]}
"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, default_nb)])]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
@ -1029,7 +1039,7 @@ class TestSearchDSL(object):
expected: Exception raised
expr = {"must": [gen_default_vector_expr(default_query),
gen_default_term_expr(keyword="terrm", values=[i for i in range(nb // 2)])]}
gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]}
query = update_query_expr(default_query, expr=expr)
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -1066,17 +1076,17 @@ class TestSearchDSL(object):
connect.create_collection(collection_term, term_fields)
term_entities = add_field(entities, field_name="term")
ids = connect.insert(collection_term, term_entities)
assert len(ids) == nb
assert len(ids) == default_nb
count = connect.count_entities(collection_term)
assert count == nb
term_param = {"term": {"term": {"values": [i for i in range(nb // 2)]}}}
assert count == default_nb
term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
expr = {"must": [gen_default_vector_expr(default_query),
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection_term, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
@ -1153,7 +1163,7 @@ class TestSearchDSL(object):
expected: 0
entities, ids = init_data(connect, collection)
ranges = {"GT": nb, "LT": 0}
ranges = {"GT": default_nb, "LT": 0}
range = gen_default_range_expr(ranges=ranges)
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
@ -1180,7 +1190,7 @@ class TestSearchDSL(object):
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
def test_query_range_one_field_not_existed(self, connect, collection):
@ -1189,7 +1199,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
range = gen_default_range_expr()
range["range"].update({"a": {"GT": 1, "LT": nb // 2}})
range["range"].update({"a": {"GT": 1, "LT": default_nb // 2}})
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
with pytest.raises(Exception) as e:
@ -1210,12 +1220,12 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
term_first = gen_default_term_expr()
term_second = gen_default_term_expr(values=[i for i in range(nb // 3)])
term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)])
expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
@ -1226,7 +1236,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
term_first = gen_default_term_expr()
term_second = gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])
term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])
expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
@ -1241,7 +1251,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
term_first = gen_default_term_expr()
term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb // 2, nb)])
term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(default_nb // 2, default_nb)])
expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
@ -1256,8 +1266,8 @@ class TestSearchDSL(object):
expected: pass
entities, ids = init_data(connect, collection)
term_first = {"int64": {"values": [i for i in range(nb // 2)]}}
term_second = {"float": {"values": [float(i) for i in range(nb // 2, nb)]}}
term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}}
term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}}
term = update_term_expr({"term": {}}, [term_first, term_second])
expr = {"must": [gen_default_vector_expr(default_query), term]}
query = update_query_expr(default_query, expr=expr)
@ -1273,12 +1283,12 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
range_one = gen_default_range_expr()
range_two = gen_default_range_expr(ranges={"GT": 1, "LT": nb // 3})
range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3})
expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
@ -1289,7 +1299,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
range_one = gen_default_range_expr()
range_two = gen_default_range_expr(ranges={"GT": nb // 2, "LT": nb})
range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
@ -1305,7 +1315,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
range_first = gen_default_range_expr()
range_second = gen_default_range_expr(field="float", ranges={"GT": nb // 2, "LT": nb})
range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb})
expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
@ -1320,12 +1330,11 @@ class TestSearchDSL(object):
expected: pass
entities, ids = init_data(connect, collection)
range_first = {"int64": {"GT": 0, "LT": nb // 2}}
range_second = {"float": {"GT": nb / 2, "LT": float(nb)}}
range_first = {"int64": {"GT": 0, "LT": default_nb // 2}}
range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}}
range = update_range_expr({"range": {}}, [range_first, range_second])
expr = {"must": [gen_default_vector_expr(default_query), range]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -1344,12 +1353,12 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
term = gen_default_term_expr()
range = gen_default_range_expr(ranges={"GT": -1, "LT": nb // 2})
range = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2})
expr = {"must": [gen_default_vector_expr(default_query), term, range]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k
def test_query_single_term_range_no_common(self, connect, collection):
@ -1359,7 +1368,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
term = gen_default_term_expr()
range = gen_default_range_expr(ranges={"GT": nb // 2, "LT": nb})
range = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
expr = {"must": [gen_default_vector_expr(default_query), term, range]}
query = update_query_expr(default_query, expr=expr)
res = connect.search(collection, query)
@ -1380,7 +1389,7 @@ class TestSearchDSL(object):
entities, ids = init_data(connect, collection)
vector1 = default_query
vector2 = gen_query_vectors(field_name, entities, top_k, nq=2)
vector2 = gen_query_vectors(field_name, entities, default_top_k, nq=2)
expr = {
"must": [vector1, vector2]
@ -1509,8 +1518,9 @@ class TestSearchInvalid(object):
with pytest.raises(Exception) as e:
res = connect.search(collection_name, default_query)
def test_search_with_invalid_tag(self, connect, collection):
# TODO(yukun)
def _test_search_with_invalid_tag(self, connect, collection):
tag = " "
with pytest.raises(Exception) as e:
res = connect.search(collection, default_query, partition_tags=tag)
@ -1541,7 +1551,7 @@ class TestSearchInvalid(object):
def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
target: test search fuction, with the wrong top_k
target: test search function, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
@ -1564,7 +1574,7 @@ class TestSearchInvalid(object):
def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
target: test search fuction, with the wrong nprobe
target: test search function, with the wrong nprobe
method: search with nprobe
expected: raise an error, and the connection is normal
@ -1572,16 +1582,18 @@ class TestSearchInvalid(object):
index_type = get_simple_index["index_type"]
if index_type in ["FLAT"]:
pytest.skip("skip in FLAT index")
if index_type != search_params["index_type"]:
pytest.skip("skip if index_type not matched")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params=search_params["search_params"])
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
def test_search_with_invalid_params_binary(self, connect, binary_collection):
target: test search fuction, with the wrong nprobe
target: test search function, with the wrong nprobe
method: search with nprobe
expected: raise an error, and the connection is normal
@ -1589,15 +1601,15 @@ class TestSearchInvalid(object):
index_type = "BIN_IVF_FLAT"
int_vectors, entities, ids = init_binary_data(connect, binary_collection)
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
connect.create_index(binary_collection, binary_field_name, {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 1024}})
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, search_params={"nprobe": 0}, metric_type="JACCARD")
connect.create_index(binary_collection, binary_field_name, {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}})
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, search_params={"nprobe": 0}, metric_type="JACCARD")
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)
def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
target: test search fuction, with empty search params
target: test search function, with empty search params
method: search with params
expected: raise an error, and the connection is normal
@ -1608,7 +1620,7 @@ class TestSearchInvalid(object):
pytest.skip("skip in FLAT index")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={})
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
@ -7,16 +7,7 @@ from multiprocessing import Pool, Process
import pytest
from utils import *
dim = 128
index_file_size = 10
collection_id = "mysql_failure"
nprobe = 1
tag = "1970_01_01"
class TestMysql:
The following cases are used to test mysql failure
@ -33,7 +24,7 @@ class TestMysql:
big_nb = 20000
index_param = {"nlist": 1024, "m": 16}
index_type = IndexType.IVF_PQ
vectors = gen_vectors(big_nb, dim)
vectors = gen_vectors(big_nb, default_dim)
status, ids = connect.insert(collection, vectors, ids=[i for i in range(big_nb)])
status = connect.flush([collection])
assert status.OK()
@ -3,131 +3,313 @@ import random
import pdb
import threading
import logging
import json
from multiprocessing import Pool, Process
import pytest
from utils import *
dim = 128
index_file_size = 10
collection_id = "test_partition_restart"
nprobe = 1
tag = "1970_01_01"
uid = "wal"
insert_interval_time = 1.5
big_nb = 100000
field_name = "float_vector"
big_entities = gen_entities(big_nb)
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
class TestRestartBase:
The following cases are used to test `create_partition` function
@pytest.fixture(scope="function", autouse=True)
def skip_check(self, connect, args):
@pytest.fixture(scope="module", autouse=True)
def skip_check(self, args):
if "service_name" not in args or not args["service_name"]:
reason = "Skip if service name not provided"
if args["service_name"].find("shards") != -1:
reason = "Skip restart cases in shards mode"
def _test_create_partition_insert_restart(self, connect, collection, args):
def _test_insert_flush(self, connect, collection, args):
target: return the same row count after server restart
method: call function: create partition, then insert, restart server and assert row count
expected: status ok, and row count keep the same
method: call function: create collection, then insert/flush, restart server and assert row count
expected: row count keep the same
status = connect.create_partition(collection, tag)
assert status.OK()
nq = 1000
vectors = gen_vectors(nq, dim)
ids = [i for i in range(nq)]
status, ids = connect.insert(collection, vectors, ids, partition_tag=tag)
assert status.OK()
status = connect.flush([collection])
assert status.OK()
status, res = connect.count_entities(collection)
assert res == nq
# restart server
if restart_server(args["service_name"]):
logging.getLogger().info("Restart success")
logging.getLogger().info("Restart failed")
# assert row count again
# debug
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
status, res = new_connect.count_entities(collection)
assert status.OK()
assert res == nq
def _test_during_creating_index_restart(self, connect, collection, args):
target: return the same row count after server restart
method: call function: insert, flush, and create index, server do restart during creating index
expected: row count, vector-id, index info keep the same
# reset auto_flush_interval
# auto_flush_interval = 100
get_ids_length = 500
timeout = 60
big_nb = 20000
index_param = {"nlist": 1024, "m": 16}
index_type = IndexType.IVF_PQ
# status, res_set = connect.set_config("db_config", "auto_flush_interval", auto_flush_interval)
# assert status.OK()
# status, res_get = connect.get_config("db_config", "auto_flush_interval")
# assert status.OK()
# assert res_get == str(auto_flush_interval)
# insert and create index
vectors = gen_vectors(big_nb, dim)
status, ids = connect.insert(collection, vectors, ids=[i for i in range(big_nb)])
status = connect.flush([collection])
assert status.OK()
status, res_count = connect.count_entities(collection)
ids = connect.insert(collection, default_entities)
ids = connect.insert(collection, default_entities)
res_count = connect.count_entities(collection)
assert status.OK()
assert res_count == big_nb
logging.getLogger().info("Start create index async")
status = connect.create_index(collection, index_type, index_param, _async=True)
assert res_count == 2 * nb
# restart server
logging.getLogger().info("Before restart server")
if restart_server(args["service_name"]):
logging.getLogger().info("Restart success")
logging.getLogger().info("Restart failed")
# check row count, index_type, vertor-id after server restart
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
status, res_count = new_connect.count_entities(collection)
assert status.OK()
assert res_count == big_nb
status, res_info = new_connect.get_index_info(collection)
assert res_info._params == index_param
assert res_info._collection_name == collection
assert res_info._index_type == index_type
logging.getLogger().info("Start restart server")
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count = new_connect.count_entities(collection)
assert res_count == 2 * nb
def _test_insert_during_flushing(self, connect, collection, args):
target: flushing will recover
method: call function: create collection, then insert/flushing, restart server and assert row count
expected: row count equals 0
# disable_autoflush()
ids = connect.insert(collection, big_entities)
connect.flush([collection], _async=True)
res_count = connect.count_entities(collection)
if res_count < big_nb:
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
timeout = 300
start_time = time.time()
while new_connect.count_entities(collection) != big_nb and (time.time() - start_time < timeout):
res_count_3 = new_connect.count_entities(collection)
assert res_count_3 == big_nb
def _test_delete_during_flushing(self, connect, collection, args):
target: flushing will recover
method: call function: create collection, then delete/flushing, restart server and assert row count
expected: row count equals (nb - delete_length)
# disable_autoflush()
ids = connect.insert(collection, big_entities)
delete_length = 1000
delete_ids = ids[big_nb//4:big_nb//4+delete_length]
delete_res = connect.delete_entity_by_id(collection, delete_ids)
connect.flush([collection], _async=True)
res_count = connect.count_entities(collection)
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
timeout = 100
start_time = time.time()
i = 1
while time.time() - start_time < timeout:
stauts, stats = new_connect.get_collection_stats(collection)
index_name = stats["partitions"][0]["segments"][0]["index_name"]
if index_name == "PQ":
i += 1
if time.time() - start_time >= timeout:
assert False
get_ids = random.sample(ids, get_ids_length)
status, res = new_connect.get_entity_by_id(collection, get_ids)
while new_connect.count_entities(collection) != big_nb - delete_length and (time.time() - start_time < timeout):
if new_connect.count_entities(collection) == big_nb - delete_length:
res_count_3 = new_connect.count_entities(collection)
assert res_count_3 == big_nb - delete_length
def _test_during_indexed(self, connect, collection, args):
target: flushing will recover
method: call function: create collection, then indexed, restart server and assert row count
expected: row count equals nb
# disable_autoflush()
ids = connect.insert(collection, big_entities)
connect.create_index(collection, field_name, default_index)
res_count = connect.count_entities(collection)
stats = connect.get_collection_stats(collection)
# logging.getLogger().info(stats)
# pdb.set_trace()
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
assert new_connect.count_entities(collection) == big_nb
stats = connect.get_collection_stats(collection)
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["data_size"] > 0
if file["index_type"] != default_index["index_type"]:
assert False
assert True
def _test_during_indexing(self, connect, collection, args):
target: flushing will recover
method: call function: create collection, then indexing, restart server and assert row count
expected: row count equals nb, server contitue to build index after restart
# disable_autoflush()
loop = 5
for i in range(loop):
ids = connect.insert(collection, big_entities)
connect.create_index(collection, field_name, default_index, _async=True)
res_count = connect.count_entities(collection)
stats = connect.get_collection_stats(collection)
# logging.getLogger().info(stats)
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
assert res_count_2 == loop * big_nb
status = new_connect._cmd("status")
assert json.loads(status)["indexing"] == True
# timeout = 100
# start_time = time.time()
# while time.time() - start_time < timeout:
# time.sleep(5)
# assert new_connect.count_entities(collection) == loop * big_nb
# stats = connect.get_collection_stats(collection)
# assert stats["row_count"] == loop * big_nb
# for file in stats["partitions"][0]["segments"][0]["files"]:
# # logging.getLogger().info(file)
# if file["field"] == field_name and file["name"] != "_raw":
# assert file["data_size"] > 0
# if file["index_type"] != default_index["index_type"]:
# continue
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["field"] == field_name and file["name"] != "_raw":
# assert file["data_size"] > 0
# if file["index_type"] != default_index["index_type"]:
# assert False
# else:
# assert True
def _test_delete_flush_during_compacting(self, connect, collection, args):
target: verify server work after restart during compaction
method: call function: create collection, then delete/flush/compacting, restart server and assert row count
call `compact` again, compact pass
expected: row count equals (nb - delete_length)
# disable_autoflush()
ids = connect.insert(collection, big_entities)
delete_length = 1000
loop = 10
for i in range(loop):
delete_ids = ids[i*delete_length:(i+1)*delete_length]
delete_res = connect.delete_entity_by_id(collection, delete_ids)
connect.compact(collection, _async=True)
res_count = connect.count_entities(collection)
assert res_count == big_nb - delete_length*loop
info = connect.get_collection_stats(collection)
size_old = info["partitions"][0]["segments"][0]["data_size"]
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
assert res_count_2 == big_nb - delete_length*loop
info = connect.get_collection_stats(collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(collection)
assert status.OK()
for index, item_id in enumerate(get_ids):
assert_equal_vector(res[index], vectors[item_id])
info = connect.get_collection_stats(collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert size_before > size_after
def _test_insert_during_flushing_multi_collections(self, connect, args):
target: flushing will recover
method: call function: create collections, then insert/flushing, restart server and assert row count
expected: row count equals 0
# disable_autoflush()
collection_num = 2
collection_list = []
for i in range(collection_num):
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection_name, big_entities)
connect.flush(collection_list, _async=True)
res_count = connect.count_entities(collection_list[-1])
if res_count < big_nb:
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection_list[-1])
timeout = 300
start_time = time.time()
while time.time() - start_time < timeout:
count_list = []
break_flag = True
for index, name in enumerate(collection_list):
tmp_count = new_connect.count_entities(name)
if tmp_count != big_nb:
break_flag = False
if break_flag == True:
for name in collection_list:
assert new_connect.count_entities(name) == big_nb
def _test_insert_during_flushing_multi_partitions(self, connect, collection, args):
target: flushing will recover
method: call function: create collection/partition, then insert/flushing, restart server and assert row count
expected: row count equals 0
# disable_autoflush()
partitions_num = 2
partitions = []
for i in range(partitions_num):
tag_tmp = gen_unique_str()
connect.create_partition(collection, tag_tmp)
ids = connect.insert(collection, big_entities, partition_tag=tag_tmp)
connect.flush([collection], _async=True)
res_count = connect.count_entities(collection)
if res_count < big_nb:
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
timeout = 300
start_time = time.time()
while new_connect.count_entities(collection) != big_nb * 2 and (time.time() - start_time < timeout):
res_count_3 = new_connect.count_entities(collection)
assert res_count_3 == big_nb * 2
@ -5,28 +5,15 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
dim = 128
index_file_size = 10
nprobe = 1
top_k = 1
tag = "1970_01_01"
nb = 6000
nq = 2
segment_row_count = 5000
entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
default_binary_fields = gen_binary_default_fields()
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim), "metric_type":"L2",
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type":"L2",
"params": {"nprobe": 10}}}}
@ -34,12 +21,12 @@ default_single_query = {
default_binary_single_query = {
"bool": {
"must": [
{"vector": {binary_field_name: {"topk": 10, "query": gen_binary_vectors(1, dim), "metric_type":"JACCARD",
"params": {"nprobe": 10}}}}
{"vector": {binary_field_name: {"topk": 10, "query": gen_binary_vectors(1, default_dim),
"metric_type":"JACCARD", "params": {"nprobe": 10}}}}
default_query, default_query_vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
default_query, default_query_vecs = gen_query_vectors(binary_field_name, default_binary_entities, 1, 2)
def ip_query():
@ -83,6 +70,14 @@ class TestCompactBase:
def get_collection_name(self, request):
yield request.param
def get_threshold(self, request):
yield request.param
def test_compact_collection_name_invalid(self, connect, get_collection_name):
@ -94,7 +89,20 @@ class TestCompactBase:
with pytest.raises(Exception) as e:
status = connect.compact(collection_name)
# assert not status.OK()
def test_compact_threshold_invalid(self, connect, collection, get_threshold):
target: compact collection with invalid name
method: compact with invalid threshold
expected: exception raised
threshold = get_threshold
if threshold != None:
with pytest.raises(Exception) as e:
status = connect.compact(collection, threshold)
def test_add_entity_and_compact(self, connect, collection):
@ -104,7 +112,7 @@ class TestCompactBase:
expected: data_size before and after Compact
# vector = gen_single_vector(dim)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
assert len(ids) == 1
# get collection info before compact
@ -125,8 +133,7 @@ class TestCompactBase:
method: add entities and compact collection
expected: data_size before and after Compact
# entities = gen_vector(nb, dim)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
# get collection info before compact
info = connect.get_collection_stats(collection)
@ -147,8 +154,8 @@ class TestCompactBase:
method: add entities, delete a few and compact collection
expected: status ok, data size maybe is smaller after compact
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -167,7 +174,35 @@ class TestCompactBase:
size_after = info["partitions"][0]["data_size"]
assert(size_before >= size_after)
def test_insert_delete_part_and_compact_threshold(self, connect, collection):
target: test add entities, delete part of them and compact
method: add entities, delete a few and compact collection
expected: status ok, data size maybe is smaller after compact
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status.OK()
# get collection info before compact
info = connect.get_collection_stats(collection)
size_before = info["partitions"][0]["data_size"]
status = connect.compact(collection, 0.1)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(collection)
size_after = info["partitions"][0]["data_size"]
assert(size_before >= size_after)
def test_insert_delete_all_and_compact(self, connect, collection):
@ -176,8 +211,8 @@ class TestCompactBase:
method: add entities, delete all and compact collection
expected: status ok, no data size in collection info because collection is empty
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
status = connect.delete_entity_by_id(collection, ids)
assert status.OK()
@ -200,14 +235,13 @@ class TestCompactBase:
method: add entities, delete half of entities in partition and compact collection
expected: status ok, data_size less than the older version
connect.create_partition(collection, tag)
assert connect.has_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
assert connect.has_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
info = connect.get_collection_stats(collection)
delete_ids = ids[:3000]
delete_ids = ids[:default_nb//2]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status.OK()
@ -242,15 +276,14 @@ class TestCompactBase:
expected: status ok, index description no change, data size smaller after compact
count = 10
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
# get collection info before compact
info = connect.get_collection_stats(collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
delete_ids = ids[:1500]
delete_ids = ids[:default_nb//2]
status = connect.delete_entity_by_id(collection, delete_ids)
assert status.OK()
@ -258,7 +291,6 @@ class TestCompactBase:
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before >= size_after)
@ -269,7 +301,7 @@ class TestCompactBase:
method: add entity and compact collection twice
expected: status ok, data size no change
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
# get collection info before compact
info = connect.get_collection_stats(collection)
@ -295,7 +327,7 @@ class TestCompactBase:
method: add entities, delete part and compact collection twice
expected: status ok, data size smaller after first compact, no change after second
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(collection, delete_ids)
@ -345,8 +377,8 @@ class TestCompactBase:
method: after compact operation, add entity
expected: status ok, entity added
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
# get collection info before compact
info = connect.get_collection_stats(collection)
@ -357,10 +389,10 @@ class TestCompactBase:
info = connect.get_collection_stats(collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before == size_after)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
res = connect.count_entities(collection)
assert res == nb+1
assert res == default_nb+1
def test_index_creation_after_compact(self, connect, collection, get_simple_index):
@ -369,7 +401,7 @@ class TestCompactBase:
method: after compact operation, create index
expected: status ok, index description no change
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
status = connect.delete_entity_by_id(collection, ids[:10])
assert status.OK()
@ -387,8 +419,8 @@ class TestCompactBase:
method: after compact operation, delete entities
expected: status ok, entities deleted
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
status = connect.compact(collection)
assert status.OK()
@ -405,14 +437,15 @@ class TestCompactBase:
method: after compact operation, search vector
expected: status ok
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
status = connect.compact(collection)
assert status.OK()
query = copy.deepcopy(default_single_query)
query["bool"]["must"][0]["vector"][field_name]["query"] = [entity[-1]["values"][0], entities[-1]["values"][0],
query["bool"]["must"][0]["vector"][field_name]["query"] = [default_entity[-1]["values"][0],
res = connect.search(collection, query)
assert len(res) == len(query["bool"]["must"][0]["vector"][field_name]["query"])
@ -434,7 +467,7 @@ class TestCompactBinary:
method: add vector and compact collection
expected: status ok, vector added
ids = connect.insert(binary_collection, binary_entity)
ids = connect.insert(binary_collection, default_binary_entity)
assert len(ids) == 1
# get collection info before compact
@ -454,8 +487,8 @@ class TestCompactBinary:
method: add entities and compact collection
expected: status ok, entities added
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
# get collection info before compact
info = connect.get_collection_stats(binary_collection)
@ -474,8 +507,8 @@ class TestCompactBinary:
method: add entities, delete a few and compact collection
expected: status ok, data size is smaller after compact
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(binary_collection, delete_ids)
@ -503,8 +536,8 @@ class TestCompactBinary:
method: add entities, delete all and compact collection
expected: status ok, no data size in collection info because collection is empty
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
status = connect.delete_entity_by_id(binary_collection, ids)
assert status.OK()
@ -526,7 +559,7 @@ class TestCompactBinary:
method: add entity and compact collection twice
expected: status ok
ids = connect.insert(binary_collection, binary_entity)
ids = connect.insert(binary_collection, default_binary_entity)
assert len(ids) == 1
# get collection info before compact
@ -552,8 +585,8 @@ class TestCompactBinary:
method: add entities, delete part and compact collection twice
expected: status ok, data size smaller after first compact, no change after second
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
delete_ids = [ids[0], ids[-1]]
status = connect.delete_entity_by_id(binary_collection, delete_ids)
@ -609,7 +642,7 @@ class TestCompactBinary:
method: after compact operation, add entity
expected: status ok, entity added
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
# get collection info before compact
info = connect.get_collection_stats(binary_collection)
@ -620,10 +653,10 @@ class TestCompactBinary:
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert(size_before == size_after)
ids = connect.insert(binary_collection, binary_entity)
ids = connect.insert(binary_collection, default_binary_entity)
res = connect.count_entities(binary_collection)
assert res == nb + 1
assert res == default_nb + 1
def test_delete_entities_after_compact(self, connect, binary_collection):
@ -632,7 +665,7 @@ class TestCompactBinary:
method: after compact operation, delete entities
expected: status ok, entities deleted
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
status = connect.compact(binary_collection)
assert status.OK()
@ -651,16 +684,16 @@ class TestCompactBinary:
method: after compact operation, search vector
expected: status ok
ids = connect.insert(binary_collection, binary_entities)
assert len(ids) == nb
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
status = connect.compact(binary_collection)
assert status.OK()
query_vecs = [raw_vectors[0]]
distance = jaccard(query_vecs[0], raw_vectors[0])
query_vecs = [default_raw_binary_vectors[0]]
distance = jaccard(query_vecs[0], default_raw_binary_vectors[0])
query = copy.deepcopy(default_binary_single_query)
query["bool"]["must"][0]["vector"][binary_field_name]["query"] = [binary_entities[-1]["values"][0],
query["bool"]["must"][0]["vector"][binary_field_name]["query"] = [default_binary_entities[-1]["values"][0],
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0]-distance) <= epsilon
@ -672,13 +705,14 @@ class TestCompactBinary:
method: after compact operation, search vector
expected: status ok
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
status = connect.compact(collection)
query = ip_query()
query["bool"]["must"][0]["vector"][field_name]["query"] = [entity[-1]["values"][0], entities[-1]["values"][0],
query["bool"]["must"][0]["vector"][field_name]["query"] = [default_entity[-1]["values"][0],
res = connect.search(collection, query)
assert len(res) == len(query["bool"]["must"][0]["vector"][field_name]["query"])
@ -8,15 +8,7 @@ import pytest
from utils import *
import ujson
dim = 128
index_file_size = 10
nprobe = 1
top_k = 1
tag = "1970_01_01"
nb = 6000
class TestCacheConfig:
@ -35,8 +27,8 @@ class TestCacheConfig:
reset configs so the tests are stable
relpy = connect.set_config("cache", "cache_size", '4GB')
config_value = connect.get_config("cache", "cache_size")
relpy = connect.set_config("cache.cache_size", '4GB')
config_value = connect.get_config("cache.cache_size")
assert config_value == '4GB'
#relpy = connect.set_config("cache", "insert_buffer_size", '2GB')
#config_value = connect.get_config("cache", "insert_buffer_size")
@ -52,7 +44,7 @@ class TestCacheConfig:
invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "cache_size")
config_value = connect.get_config(config+str(".cache_size"))
def test_get_cache_size_invalid_child_key(self, connect, collection):
@ -64,7 +56,7 @@ class TestCacheConfig:
invalid_configs = ["Cpu_cache_size", "cpu cache_size", "cpucachecapacity"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("cache", config)
config_value = connect.get_config("cache."+config)
def test_get_cache_size_valid(self, connect, collection):
@ -73,7 +65,7 @@ class TestCacheConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("cache", "cache_size")
config_value = connect.get_config("cache.cache_size")
assert config_value
@ -86,7 +78,7 @@ class TestCacheConfig:
invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "insert_buffer_size")
config_value = connect.get_config(config+".insert_buffer_size")
def test_get_insert_buffer_size_invalid_child_key(self, connect, collection):
@ -98,7 +90,7 @@ class TestCacheConfig:
invalid_configs = ["Insert_buffer size", "insert buffer_size", "insertbuffersize"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("cache", config)
config_value = connect.get_config("cache."+config)
def test_get_insert_buffer_size_valid(self, connect, collection):
@ -107,7 +99,7 @@ class TestCacheConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("cache", "insert_buffer_size")
config_value = connect.get_config("cache.insert_buffer_size")
assert config_value
@ -120,7 +112,7 @@ class TestCacheConfig:
invalid_configs = ["preloadtable", "preload collection "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("cache", config)
config_value = connect.get_config("cache."+config)
def test_get_preload_collection_valid(self, connect, collection):
@ -129,7 +121,7 @@ class TestCacheConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("cache", "preload_collection")
config_value = connect.get_config("cache.preload_collection")
assert config_value == ''
@ -165,7 +157,7 @@ class TestCacheConfig:
invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config(config, "cache_size", '4294967296')
relpy = connect.set_config(config+".cache_size", '4294967296')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -179,7 +171,7 @@ class TestCacheConfig:
invalid_configs = ["abc", 1]
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config("cache", config, '4294967296')
relpy = connect.set_config("cache."+config, '4294967296')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -190,8 +182,8 @@ class TestCacheConfig:
expected: status ok, set successfully
relpy = connect.set_config("cache", "cache_size", '2147483648')
config_value = connect.get_config("cache", "cache_size")
relpy = connect.set_config("cache.cache_size", '2147483648')
config_value = connect.get_config("cache.cache_size")
assert config_value == '2GB'
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -204,12 +196,12 @@ class TestCacheConfig:
for i in range(20):
relpy = connect.set_config("cache", "cache_size", '4294967296')
config_value = connect.get_config("cache", "cache_size")
relpy = connect.set_config("cache.cache_size", '4294967296')
config_value = connect.get_config("cache.cache_size")
assert config_value == '4294967296'
for i in range(20):
relpy = connect.set_config("cache", "cache_size", '2147483648')
config_value = connect.get_config("cache", "cache_size")
relpy = connect.set_config("cache.cache_size", '2147483648')
config_value = connect.get_config("cache.cache_size")
assert config_value == '2147483648'
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -224,7 +216,7 @@ class TestCacheConfig:
invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config(config, "insert_buffer_size", '1073741824')
relpy = connect.set_config(config+".insert_buffer_size", '1073741824')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -235,7 +227,7 @@ class TestCacheConfig:
expected: status ok, set successfully
relpy = connect.set_config("cache", "insert_buffer_size", '2GB')
relpy = connect.set_config("cache.insert_buffer_size", '2GB')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -248,10 +240,10 @@ class TestCacheConfig:
for i in range(20):
with pytest.raises(Exception) as e:
relpy = connect.set_config("cache", "insert_buffer_size", '1GB')
relpy = connect.set_config("cache.insert_buffer_size", '1GB')
for i in range(20):
with pytest.raises(Exception) as e:
relpy = connect.set_config("cache", "insert_buffer_size", '2GB')
relpy = connect.set_config("cache.insert_buffer_size", '2GB')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -265,7 +257,7 @@ class TestCacheConfig:
mem_total = self.get_memory_total(connect)
with pytest.raises(Exception) as e:
relpy = connect.set_config("cache", "cache_size", str(int(mem_total + 1)+''))
relpy = connect.set_config("cache.cache_size", str(int(mem_total + 1)+''))
@ -292,7 +284,7 @@ class TestGPUConfig:
invalid_configs = ["Engine_config", "engine config"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "gpu_search_threshold")
config_value = connect.get_config(config+".gpu_search_threshold")
def test_get_gpu_search_threshold_invalid_child_key(self, connect, collection):
@ -306,7 +298,7 @@ class TestGPUConfig:
invalid_configs = ["Gpu_search threshold", "gpusearchthreshold"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("gpu", config)
config_value = connect.get_config("gpu."+config)
def test_get_gpu_search_threshold_valid(self, connect, collection):
@ -317,7 +309,7 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
config_value = connect.get_config("gpu", "gpu_search_threshold")
config_value = connect.get_config("gpu.gpu_search_threshold")
assert config_value
@ -336,7 +328,7 @@ class TestGPUConfig:
invalid_configs = ["abc", 1]
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config("gpu", config, 1000)
relpy = connect.set_config("gpu."+config, 1000)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -351,7 +343,7 @@ class TestGPUConfig:
invalid_configs = ["Engine_config", "engine config"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config(config, "gpu_search_threshold", 1000)
relpy = connect.set_config(config+".gpu_search_threshold", 1000)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -363,8 +355,8 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
relpy = connect.set_config("gpu", "gpu_search_threshold", 2000)
config_value = connect.get_config("gpu", "gpu_search_threshold")
relpy = connect.set_config("gpu.gpu_search_threshold", 2000)
config_value = connect.get_config("gpu.gpu_search_threshold")
assert config_value == '2000'
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -377,10 +369,10 @@ class TestGPUConfig:
for i in [-1, "1000\n", "1000\t", "1000.0", 1000.35]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("gpu", "use_blas_threshold", i)
relpy = connect.set_config("gpu.use_blas_threshold", i)
if str(connect._cmd("mode")) == "GPU":
with pytest.raises(Exception) as e:
relpy = connect.set_config("gpu", "gpu_search_threshold", i)
relpy = connect.set_config("gpu.gpu_search_threshold", i)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -388,8 +380,8 @@ class TestGPUConfig:
reset configs so the tests are stable
relpy = connect.set_config("gpu", "cache_size", 1)
config_value = connect.get_config("gpu", "cache_size")
relpy = connect.set_config("gpu.cache_size", 1)
config_value = connect.get_config("gpu.cache_size")
assert config_value == '1'
#follows can not be changed
@ -416,7 +408,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "enable")
config_value = connect.get_config(config+".enable")
def test_get_gpu_enable_invalid_child_key(self, connect, collection):
@ -430,7 +422,7 @@ class TestGPUConfig:
invalid_configs = ["Enab_le", "enab_le ", "disable", "true"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("gpu", config)
config_value = connect.get_config("gpu."+config)
def test_get_gpu_enable_valid(self, connect, collection):
@ -441,7 +433,7 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
config_value = connect.get_config("gpu", "enable")
config_value = connect.get_config("gpu.enable")
assert config_value == "true" or config_value == "false"
@ -457,7 +449,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "cache_size")
config_value = connect.get_config(config+".cache_size")
def test_get_cache_size_invalid_child_key(self, connect, collection):
@ -471,7 +463,7 @@ class TestGPUConfig:
invalid_configs = ["Cache_capacity", "cachecapacity"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("gpu", config)
config_value = connect.get_config("gpu."+config)
def test_get_cache_size_valid(self, connect, collection):
@ -482,7 +474,7 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
config_value = connect.get_config("gpu", "cache_size")
config_value = connect.get_config("gpu.cache_size")
def test_get_search_devices_invalid_parent_key(self, connect, collection):
@ -497,7 +489,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "search_devices")
config_value = connect.get_config(config+".search_devices")
def test_get_search_devices_invalid_child_key(self, connect, collection):
@ -511,7 +503,7 @@ class TestGPUConfig:
invalid_configs = ["Search_resources"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("gpu", config)
config_value = connect.get_config("gpu."+config)
def test_get_search_devices_valid(self, connect, collection):
@ -522,7 +514,7 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
config_value = connect.get_config("gpu", "search_devices")
config_value = connect.get_config("gpu.search_devices")
@ -538,7 +530,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config(config, "build_index_devices")
config_value = connect.get_config(config+".build_index_devices")
def test_get_build_index_devices_invalid_child_key(self, connect, collection):
@ -552,7 +544,7 @@ class TestGPUConfig:
invalid_configs = ["Build_index_resources"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("gpu", config)
config_value = connect.get_config("gpu."+config)
def test_get_build_index_devices_valid(self, connect, collection):
@ -563,7 +555,7 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
config_value = connect.get_config("gpu", "build_index_devices")
config_value = connect.get_config("gpu.build_index_devices")
assert config_value
@ -586,7 +578,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config(config, "enable", "true")
relpy = connect.set_config(config+".enable", "true")
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -602,7 +594,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config("gpu", config, "true")
relpy = connect.set_config("gpu."+config, "true")
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -616,7 +608,7 @@ class TestGPUConfig:
pytest.skip("Only support GPU mode")
for i in [-1, -2, 100]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("gpu", "enable", i)
relpy = connect.set_config("gpu.enable", i)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -631,7 +623,7 @@ class TestGPUConfig:
valid_configs = ["off", "False", "0", "nO", "on", "True", 1, "yES"]
for config in valid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config("gpu", "enable", config)
relpy = connect.set_config("gpu.enable", config)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -647,7 +639,7 @@ class TestGPUConfig:
for config in invalid_configs:
with pytest.raises(Exception) as e:
relpy = connect.set_config(config, "cache_size", 2)
relpy = connect.set_config(config+".cache_size", 2)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -659,7 +651,7 @@ class TestGPUConfig:
if str(connect._cmd("mode")) == "CPU":
pytest.skip("Only support GPU mode")
relpy = connect.set_config("gpu", "cache_size", 2)
relpy = connect.set_config("gpu.cache_size", 2)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -786,7 +778,7 @@ class TestNetworkConfig:
invalid_configs = ["Address", "addresses", "address "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("network", config)
config_value = connect.get_config("network."+config)
def test_get_address_valid(self, connect, collection):
@ -795,7 +787,7 @@ class TestNetworkConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("network", "bind.address")
config_value = connect.get_config("network.bind.address")
def test_get_port_invalid_child_key(self, connect, collection):
@ -807,7 +799,7 @@ class TestNetworkConfig:
invalid_configs = ["Port", "PORT", "port "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("network", config)
config_value = connect.get_config("network."+config)
def test_get_port_valid(self, connect, collection):
@ -816,7 +808,7 @@ class TestNetworkConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("network", "http.port")
config_value = connect.get_config("network.http.port")
assert config_value
@ -829,7 +821,7 @@ class TestNetworkConfig:
invalid_configs = ["webport", "Web_port", "http port "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("network", config)
config_value = connect.get_config("network."+config)
def test_get_http_port_valid(self, connect, collection):
@ -838,7 +830,7 @@ class TestNetworkConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("network", "http.port")
config_value = connect.get_config("network.http.port")
assert config_value
@ -863,7 +855,7 @@ class TestNetworkConfig:
expected: status not ok
with pytest.raises(Exception) as e:
relpy = connect.set_config("network", "child_key", 19530)
relpy = connect.set_config("network.child_key", 19530)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -873,7 +865,7 @@ class TestNetworkConfig:
method: call set_config correctly
expected: status ok, set successfully
relpy = connect.set_config("network", "bind.address", '')
relpy = connect.set_config("network.bind.address", '')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_port_valid(self, connect, collection):
@ -883,7 +875,7 @@ class TestNetworkConfig:
expected: status ok, set successfully
for valid_port in [1025, 65534, 12345, "19530"]:
relpy = connect.set_config("network", "http.port", valid_port)
relpy = connect.set_config("network.http.port", valid_port)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_port_invalid(self, connect, collection):
@ -895,7 +887,7 @@ class TestNetworkConfig:
for invalid_port in [1024, 65535, "0", "True", "100000"]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("network", "http.port", invalid_port)
relpy = connect.set_config("network.http.port", invalid_port)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_http_port_valid(self, connect, collection):
@ -905,7 +897,7 @@ class TestNetworkConfig:
expected: status ok, set successfully
for valid_http_port in [1025, 65534, "12345", 19121]:
relpy = connect.set_config("network", "http.port", valid_http_port)
relpy = connect.set_config("network.http.port", valid_http_port)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_http_port_invalid(self, connect, collection):
@ -916,7 +908,7 @@ class TestNetworkConfig:
for invalid_http_port in [1024, 65535, "0", "True", "1000000"]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("network", "http.port", invalid_http_port)
relpy = connect.set_config("network.http.port", invalid_http_port)
class TestGeneralConfig:
@ -940,7 +932,7 @@ class TestGeneralConfig:
invalid_configs = ["backend_Url", "backend-url", "meta uri "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("general", config)
config_value = connect.get_config("general."+config)
def test_get_meta_uri_valid(self, connect, collection):
@ -949,7 +941,7 @@ class TestGeneralConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("general", "meta_uri")
config_value = connect.get_config("general.meta_uri")
assert config_value
@ -962,7 +954,7 @@ class TestGeneralConfig:
invalid_configs = ["time", "time_zone "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("general", config)
config_value = connect.get_config("general."+config)
def test_get_timezone_valid(self, connect, collection):
@ -971,7 +963,7 @@ class TestGeneralConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("general", "timezone")
config_value = connect.get_config("general.timezone")
assert "UTC" in config_value
@ -989,7 +981,7 @@ class TestGeneralConfig:
for invalid_timezone in ["utc++8", "UTC++8"]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("general", "timezone", invalid_timezone)
relpy = connect.set_config("general.timezone", invalid_timezone)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -1000,7 +992,7 @@ class TestGeneralConfig:
expected: status not ok
with pytest.raises(Exception) as e:
relpy = connect.set_config("general", "child_key", 1)
relpy = connect.set_config("general.child_key", 1)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -1010,7 +1002,7 @@ class TestGeneralConfig:
method: call set_config correctly
expected: status ok, set successfully
relpy = connect.set_config("general", "meta_uri", 'sqlite://:@:/')
relpy = connect.set_config("general.meta_uri", 'sqlite://:@:/')
class TestStorageConfig:
@ -1034,7 +1026,7 @@ class TestStorageConfig:
invalid_configs = ["Primary_path", "primarypath", "pa_th "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("storage", config)
config_value = connect.get_config("storage."+config)
def test_get_path_valid(self, connect, collection):
@ -1043,7 +1035,7 @@ class TestStorageConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("storage", "path")
config_value = connect.get_config("storage.path")
assert config_value
@ -1056,7 +1048,7 @@ class TestStorageConfig:
invalid_configs = ["autoFlushInterval", "auto_flush", "auto_flush interval "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("storage", config)
config_value = connect.get_config("storage."+config)
def test_get_auto_flush_interval_valid(self, connect, collection):
@ -1065,7 +1057,7 @@ class TestStorageConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("storage", "auto_flush_interval")
config_value = connect.get_config("storage.auto_flush_interval")
@ -1081,7 +1073,7 @@ class TestStorageConfig:
expected: status not ok
with pytest.raises(Exception) as e:
relpy = connect.set_config("storage", "child_key", "")
relpy = connect.set_config("storage.child_key", "")
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -1091,7 +1083,7 @@ class TestStorageConfig:
method: call set_config correctly
expected: status ok, set successfully
relpy = connect.set_config("storage", "path", '/var/lib/milvus')
relpy = connect.set_config("storage.path", '/var/lib/milvus')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_auto_flush_interval_valid(self, connect, collection):
@ -1102,8 +1094,8 @@ class TestStorageConfig:
for valid_auto_flush_interval in [2, 1]:
relpy = connect.set_config("storage", "auto_flush_interval", valid_auto_flush_interval)
config_value = connect.get_config("storage", "auto_flush_interval")
relpy = connect.set_config("storage.auto_flush_interval", valid_auto_flush_interval)
config_value = connect.get_config("storage.auto_flush_interval")
assert config_value == str(valid_auto_flush_interval)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -1115,7 +1107,7 @@ class TestStorageConfig:
for invalid_auto_flush_interval in [-1, "1.5", "invalid", "1+2"]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("storage", "auto_flush_interval", invalid_auto_flush_interval)
relpy = connect.set_config("storage.auto_flush_interval", invalid_auto_flush_interval)
class TestMetricConfig:
@ -1139,7 +1131,7 @@ class TestMetricConfig:
invalid_configs = ["enablemonitor", "Enable_monitor", "en able "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("metric", config)
config_value = connect.get_config("metric."+config)
def test_get_enable_valid(self, connect, collection):
@ -1148,7 +1140,7 @@ class TestMetricConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("metric", "enable")
config_value = connect.get_config("metric.enable")
assert config_value
@ -1161,7 +1153,7 @@ class TestMetricConfig:
invalid_configs = ["Add ress", "addresses", "add ress "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("metric", config)
config_value = connect.get_config("metric."+config)
def test_get_address_valid(self, connect, collection):
@ -1170,7 +1162,7 @@ class TestMetricConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("metric", "address")
config_value = connect.get_config("metric.address")
assert config_value
@ -1183,7 +1175,7 @@ class TestMetricConfig:
invalid_configs = ["Po_rt", "PO_RT", "po_rt "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("metric", config)
config_value = connect.get_config("metric."+config)
def test_get_port_valid(self, connect, collection):
@ -1192,7 +1184,7 @@ class TestMetricConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("metric", "port")
config_value = connect.get_config("metric.port")
assert config_value
@ -1209,7 +1201,7 @@ class TestMetricConfig:
expected: status not ok
with pytest.raises(Exception) as e:
relpy = connect.set_config("metric", "child_key", 19530)
relpy = connect.set_config("metric.child_key", 19530)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_enable_valid(self, connect, collection):
@ -1219,7 +1211,7 @@ class TestMetricConfig:
expected: status ok, set successfully
for valid_enable in ["false", "true"]:
relpy = connect.set_config("metric", "enable", valid_enable)
relpy = connect.set_config("metric.enable", valid_enable)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -1229,7 +1221,7 @@ class TestMetricConfig:
method: call set_config correctly
expected: status ok, set successfully
relpy = connect.set_config("metric", "address", '')
relpy = connect.set_config("metric.address", '')
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_port_valid(self, connect, collection):
@ -1239,7 +1231,7 @@ class TestMetricConfig:
expected: status ok, set successfully
for valid_port in [1025, 65534, "19530", "9091"]:
relpy = connect.set_config("metric", "port", valid_port)
relpy = connect.set_config("metric.port", valid_port)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_port_invalid(self, connect, collection):
@ -1250,7 +1242,7 @@ class TestMetricConfig:
for invalid_port in [1024, 65535, "0", "True", "100000"]:
with pytest.raises(Exception) as e:
relpy = connect.set_config("metric", "port", invalid_port)
relpy = connect.set_config("metric.port", invalid_port)
class TestWALConfig:
@ -1274,7 +1266,7 @@ class TestWALConfig:
invalid_configs = ["enabled", "Enab_le", "enable_"]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("wal", config)
config_value = connect.get_config("wal."+config)
def test_get_enable_valid(self, connect, collection):
@ -1283,7 +1275,7 @@ class TestWALConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("wal", "enable")
config_value = connect.get_config("wal.enable")
assert config_value
@ -1296,7 +1288,7 @@ class TestWALConfig:
invalid_configs = ["recovery-error-ignore", "Recovery error_ignore", "recoveryxerror_ignore "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("wal", config)
config_value = connect.get_config("wal."+config)
def test_get_recovery_error_ignore_valid(self, connect, collection):
@ -1305,7 +1297,7 @@ class TestWALConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("wal", "recovery_error_ignore")
config_value = connect.get_config("wal.recovery_error_ignore")
assert config_value
@ -1318,7 +1310,7 @@ class TestWALConfig:
invalid_configs = ["buffersize", "Buffer size", "buffer size "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("wal", config)
config_value = connect.get_config("wal."+config)
def test_get_buffer_size_valid(self, connect, collection):
@ -1327,7 +1319,7 @@ class TestWALConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("wal", "buffer_size")
config_value = connect.get_config("wal.buffer_size")
assert config_value
@ -1340,7 +1332,7 @@ class TestWALConfig:
invalid_configs = ["wal", "Wal_path", "wal_path "]
for config in invalid_configs:
with pytest.raises(Exception) as e:
config_value = connect.get_config("wal", config)
config_value = connect.get_config("wal."+config)
def test_get_wal_path_valid(self, connect, collection):
@ -1349,7 +1341,7 @@ class TestWALConfig:
method: call get_config correctly
expected: status ok
config_value = connect.get_config("wal", "path")
config_value = connect.get_config("wal.path")
assert config_value
@ -1366,7 +1358,7 @@ class TestWALConfig:
expected: status not ok
with pytest.raises(Exception) as e:
relpy = connect.set_config("wal", "child_key", 256)
relpy = connect.set_config("wal.child_key", 256)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_enable_valid(self, connect, collection):
@ -1376,7 +1368,7 @@ class TestWALConfig:
expected: status ok, set successfully
for valid_enable in ["false", "true"]:
relpy = connect.set_config("wal", "enable", valid_enable)
relpy = connect.set_config("wal.enable", valid_enable)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_recovery_error_ignore_valid(self, connect, collection):
@ -1386,7 +1378,7 @@ class TestWALConfig:
expected: status ok, set successfully
for valid_recovery_error_ignore in ["false", "true"]:
relpy = connect.set_config("wal", "recovery_error_ignore", valid_recovery_error_ignore)
relpy = connect.set_config("wal.recovery_error_ignore", valid_recovery_error_ignore)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
def test_set_buffer_size_valid_A(self, connect, collection):
@ -1396,7 +1388,7 @@ class TestWALConfig:
expected: status ok, set successfully
for valid_buffer_size in ["64MB", "128MB", "4096MB", "1000MB", "256MB"]:
relpy = connect.set_config("wal", "buffer_size", valid_buffer_size)
relpy = connect.set_config("wal.buffer_size", valid_buffer_size)
@pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
@ -1406,5 +1398,5 @@ class TestWALConfig:
method: call set_config correctly
expected: status ok, set successfully
relpy = connect.set_config("wal", "path", "/var/lib/milvus/wal")
relpy = connect.set_config("wal.path", "/var/lib/milvus/wal")
@ -5,32 +5,18 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
index_file_size = 10
collection_id = "test_flush"
nprobe = 1
tag = "1970_01_01"
top_k = 1
nb = 6000
tag = "partition_tag"
field_name = "float_vector"
entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
default_single_query = {
"bool": {
"must": [
{"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim), "metric_type":"L2","params": {"nprobe": 10}}}}
{"vector": {default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, default_dim),
"metric_type": "L2", "params": {"nprobe": 10}}}}
class TestFlushBase:
@ -77,8 +63,8 @@ class TestFlushBase:
method: flush collection with no vectors
expected: no error raised
ids = connect.insert(collection, entities)
assert len(ids) == nb
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
status = connect.delete_entity_by_id(collection, ids)
assert status.OK()
res = connect.count_entities(collection)
@ -91,36 +77,33 @@ class TestFlushBase:
method: add entities into partition in collection, flush serveral times
expected: the length of ids and the collection row count
# vector = gen_vector(nb, dim)
connect.create_partition(id_collection, tag)
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = connect.insert(id_collection, entities, ids)
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
ids = connect.insert(id_collection, default_entities, ids)
res_count = connect.count_entities(id_collection)
assert res_count == nb
ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
assert len(ids) == nb
assert res_count == default_nb
ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
assert len(ids) == default_nb
res_count = connect.count_entities(id_collection)
assert res_count == nb * 2
assert res_count == default_nb * 2
def test_add_partitions_flush(self, connect, id_collection):
method: add entities into partitions in collection, flush one
expected: the length of ids and the collection row count
# vectors = gen_vectors(nb, dim)
tag_new = gen_unique_str()
connect.create_partition(id_collection, tag)
connect.create_partition(id_collection, default_tag)
connect.create_partition(id_collection, tag_new)
ids = [i for i in range(nb)]
ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
ids = [i for i in range(default_nb)]
ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
ids = connect.insert(id_collection, entities, ids, partition_tag=tag_new)
ids = connect.insert(id_collection, default_entities, ids, partition_tag=tag_new)
res = connect.count_entities(id_collection)
assert res == 2 * nb
assert res == 2 * default_nb
def test_add_collections_flush(self, connect, id_collection):
@ -130,18 +113,17 @@ class TestFlushBase:
collection_new = gen_unique_str()
default_fields = gen_default_fields(False)
connect.create_collection(collection_new, default_fields)
connect.create_partition(id_collection, tag)
connect.create_partition(collection_new, tag)
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
ids = connect.insert(collection_new, entities, ids, partition_tag=tag)
connect.create_partition(id_collection, default_tag)
connect.create_partition(collection_new, default_tag)
ids = [i for i in range(default_nb)]
ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
ids = connect.insert(collection_new, default_entities, ids, partition_tag=default_tag)
res = connect.count_entities(id_collection)
assert res == nb
assert res == default_nb
res = connect.count_entities(collection_new)
assert res == nb
assert res == default_nb
def test_add_collections_fields_flush(self, connect, id_collection, get_filter_field, get_vector_field):
@ -154,22 +136,21 @@ class TestFlushBase:
collection_new = gen_unique_str("test_flush")
fields = {
"fields": [filter_field, vector_field],
"segment_row_limit": segment_row_count,
"segment_row_limit": default_segment_row_limit,
"auto_id": False
connect.create_collection(collection_new, fields)
connect.create_partition(id_collection, tag)
connect.create_partition(collection_new, tag)
# vectors = gen_vectors(nb, dim)
entities_new = gen_entities_by_fields(fields["fields"], nb_new, dim)
ids = [i for i in range(nb)]
connect.create_partition(id_collection, default_tag)
connect.create_partition(collection_new, default_tag)
entities_new = gen_entities_by_fields(fields["fields"], nb_new, default_dim)
ids = [i for i in range(default_nb)]
ids_new = [i for i in range(nb_new)]
ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
ids = connect.insert(collection_new, entities_new, ids_new, partition_tag=tag)
ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
ids = connect.insert(collection_new, entities_new, ids_new, partition_tag=default_tag)
res = connect.count_entities(id_collection)
assert res == nb
assert res == default_nb
res = connect.count_entities(collection_new)
assert res == nb_new
@ -178,8 +159,7 @@ class TestFlushBase:
method: add entities, flush serveral times
expected: no error raised
# vectors = gen_vectors(nb, dim)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
for i in range(10):
res = connect.count_entities(collection)
@ -194,15 +174,14 @@ class TestFlushBase:
method: add entities
expected: no error raised
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = connect.insert(id_collection, entities, ids)
ids = [i for i in range(default_nb)]
ids = connect.insert(id_collection, default_entities, ids)
timeout = 20
start_time = time.time()
while (time.time() - start_time < timeout):
res = connect.count_entities(id_collection)
if res == nb:
if res == default_nb:
if time.time() - start_time > timeout:
assert False
@ -222,23 +201,21 @@ class TestFlushBase:
method: add entities, with same ids, count(same ids) < 15, > 15
expected: the length of ids and the collection row count
# vectors = gen_vectors(nb, dim)
ids = [i for i in range(nb)]
ids = [i for i in range(default_nb)]
for i, item in enumerate(ids):
if item <= same_ids:
ids[i] = 0
ids = connect.insert(id_collection, entities, ids)
ids = connect.insert(id_collection, default_entities, ids)
res = connect.count_entities(id_collection)
assert res == nb
assert res == default_nb
def test_delete_flush_multiable_times(self, connect, collection):
method: delete entities, flush serveral times
expected: no error raised
# vectors = gen_vectors(nb, dim)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
status = connect.delete_entity_by_id(collection, [ids[-1]])
assert status.OK()
for i in range(10):
@ -257,7 +234,7 @@ class TestFlushBase:
ids = []
for i in range(5):
tmp_ids = connect.insert(collection, entities)
tmp_ids = connect.insert(collection, default_entities)
@ -302,15 +279,19 @@ class TestFlushAsync:
status = future.result()
def test_flush_async_long(self, connect, collection):
# vectors = gen_vectors(nb, dim)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
future = connect.flush([collection], _async=True)
status = future.result()
def test_flush_async_long_drop_collection(self, connect, collection):
for i in range(5):
ids = connect.insert(collection, default_entities)
future = connect.flush([collection], _async=True)
def test_flush_async(self, connect, collection):
nb = 100000
vectors = gen_vectors(nb, dim)
connect.insert(collection, entities)
connect.insert(collection, default_entities)
future = connect.flush([collection], _async=True, _callback=self.check_status)
@ -7,26 +7,14 @@ import numpy
import pytest
import sklearn.preprocessing
from utils import *
from constants import *
nb = 6000
dim = 128
index_file_size = 10
uid = "test_index"
nprobe = 1
top_k = 5
tag = "1970_01_01"
NLIST = 4046
INVALID_NLIST = 100000000
field_name = "float_vector"
binary_field_name = "binary_vector"
collection_id = "index"
default_index_type = "FLAT"
entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
query, query_vecs = gen_query_vectors(field_name, entities, top_k, 1)
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 1024}, "metric_type": "L2"}
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1)
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
class TestIndexBase:
@ -46,7 +34,7 @@ class TestIndexBase:
def get_nq(self, request):
@ -65,9 +53,32 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
def test_create_index_on_field_not_existed(self, connect, collection, get_simple_index):
target: test create index interface
method: create collection and add entities in it, create index on field not existed
expected: error raised
tmp_field_name = gen_unique_str()
ids = connect.insert(collection, default_entities)
with pytest.raises(Exception) as e:
connect.create_index(collection, tmp_field_name, get_simple_index)
def test_create_index_on_field(self, connect, collection, get_simple_index):
target: test create index interface
method: create collection and add entities in it, create index on other field
expected: error raised
tmp_field_name = "int64"
ids = connect.insert(collection, default_entities)
with pytest.raises(Exception) as e:
connect.create_index(collection, tmp_field_name, get_simple_index)
def test_create_index_no_vectors(self, connect, collection, get_simple_index):
@ -84,8 +95,8 @@ class TestIndexBase:
method: create collection, create partition, and add entities in it, create index
expected: return search success
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.create_index(collection, field_name, get_simple_index)
@ -96,8 +107,8 @@ class TestIndexBase:
method: create collection, create partition, and add entities in it, create index
expected: return search success
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.create_index(collection, field_name, get_simple_index)
@ -117,13 +128,13 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
connect.create_index(collection, field_name, get_simple_index)
nq = get_nq
index_type = get_simple_index["index_type"]
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param)
res = connect.search(collection, query)
assert len(res) == nq
@ -135,7 +146,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
def build(connect):
connect.create_index(collection, field_name, default_index)
@ -158,7 +169,7 @@ class TestIndexBase:
, make sure the collection name not in index
expected: create index failed
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
connect.create_index(collection_name, field_name, default_index)
@ -171,10 +182,10 @@ class TestIndexBase:
expected: create index ok, and count correct
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
count = connect.count_entities(collection)
assert count == nb
assert count == default_nb
@ -196,13 +207,13 @@ class TestIndexBase:
method: create another index with different index_params after index have been built
expected: return code 0, and describe index result equals with the second index params
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}]
for index in indexs:
connect.create_index(collection, field_name, index)
stats = connect.get_collection_stats(collection)
# assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
def test_create_index_ip(self, connect, collection, get_simple_index):
@ -211,7 +222,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
@ -232,8 +243,8 @@ class TestIndexBase:
method: create collection, create partition, and add entities in it, create index
expected: return search success
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
@ -245,8 +256,8 @@ class TestIndexBase:
method: create collection, create partition, and add entities in it, create index
expected: return search success
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
@ -259,14 +270,14 @@ class TestIndexBase:
expected: return search success
metric_type = "IP"
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
nq = get_nq
index_type = get_simple_index["index_type"]
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, search_params=search_param)
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param)
res = connect.search(collection, query)
assert len(res) == nq
@ -278,7 +289,7 @@ class TestIndexBase:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
def build(connect):
default_index["metric_type"] = "IP"
@ -302,7 +313,7 @@ class TestIndexBase:
, make sure the collection name not in index
expected: return code not equals to 0, create index failed
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
default_index["metric_type"] = "IP"
with pytest.raises(Exception) as e:
connect.create_index(collection_name, field_name, default_index)
@ -316,10 +327,10 @@ class TestIndexBase:
default_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
count = connect.count_entities(collection)
assert count == nb
assert count == default_nb
@ -342,13 +353,13 @@ class TestIndexBase:
method: create another index with different index_params after index have been built
expected: return code 0, and describe index result equals with the second index params
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}]
for index in indexs:
connect.create_index(collection, field_name, index)
stats = connect.get_collection_stats(collection)
# assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
@ -402,7 +413,7 @@ class TestIndexBase:
, make sure the collection name not in index, and then drop it
expected: return code not equals to 0, drop index failed
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
connect.drop_index(collection_name, field_name)
@ -526,7 +537,7 @@ class TestIndexBinary:
def get_nq(self, request):
@ -545,7 +556,7 @@ class TestIndexBinary:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
@ -555,8 +566,8 @@ class TestIndexBinary:
method: create collection, create partition, and add entities in it, create index
expected: return search success
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
@ -567,9 +578,9 @@ class TestIndexBinary:
expected: return search success
nq = get_nq
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq, metric_type="JACCARD")
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
res = connect.search(binary_collection, query, search_params=search_param)
@ -583,8 +594,9 @@ class TestIndexBinary:
expected: return create_index failure
# insert 6000 vectors
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
if get_l2_index["index_type"] == "BIN_FLAT":
res = connect.create_index(binary_collection, binary_field_name, get_l2_index)
@ -603,11 +615,11 @@ class TestIndexBinary:
method: create collection and add entities in it, create index, call describe index
expected: return code 0, and index instructure
ids = connect.insert(binary_collection, binary_entities)
ids = connect.insert(binary_collection, default_binary_entities)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
for partition in stats["partitions"]:
segments = partition["segments"]
if segments:
@ -622,13 +634,13 @@ class TestIndexBinary:
method: create collection, create partition and add entities in it, create index, call describe index
expected: return code 0, and index instructure
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
assert len(stats["partitions"]) == 2
for partition in stats["partitions"]:
segments = partition["segments"]
@ -664,14 +676,14 @@ class TestIndexBinary:
method: create collection, create partition and add entities in it, create index on collection, call drop collection index
expected: return code 0, and default index param
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
stats = connect.get_collection_stats(binary_collection)
connect.drop_index(binary_collection, binary_field_name)
stats = connect.get_collection_stats(binary_collection)
assert stats["row_count"] == nb
assert stats["row_count"] == default_nb
for partition in stats["partitions"]:
segments = partition["segments"]
if segments:
@ -714,7 +726,7 @@ class TestIndexInvalid(object):
def get_index(self, request):
yield request.param
def test_create_index_with_invalid_index_params(self, connect, collection, get_index):
with pytest.raises(Exception) as e:
@ -760,7 +772,7 @@ class TestIndexAsync:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
logging.getLogger().info("start index")
future = connect.create_index(collection, field_name, get_simple_index, _async=True)
logging.getLogger().info("before result")
@ -768,6 +780,20 @@ class TestIndexAsync:
def test_create_index_drop(self, connect, collection, get_simple_index):
target: test create index interface
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, default_entities)
logging.getLogger().info("start index")
future = connect.create_index(collection, field_name, get_simple_index, _async=True)
def test_create_index_with_invalid_collectionname(self, connect):
collection_name = " "
future = connect.create_index(collection_name, field_name, default_index, _async=True)
@ -781,7 +807,7 @@ class TestIndexAsync:
method: create collection and add entities in it, create index
expected: return search success
ids = connect.insert(collection, entities)
ids = connect.insert(collection, default_entities)
logging.getLogger().info("start index")
future = connect.create_index(collection, field_name, get_simple_index, _async=True,
@ -9,11 +9,8 @@ from multiprocessing import Process
import sklearn.preprocessing
from utils import *
dim = 128
index_file_size = 10
collection_id = "test_mix"
add_interval_time = 5
vectors = gen_vectors(10000, dim)
vectors = gen_vectors(10000, default_dim)
vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2')
vectors = vectors.tolist()
top_k = 1
@ -24,7 +21,6 @@ nlist = 128
class TestMixBase:
# disable
def _test_search_during_createIndex(self, args):
loops = 10000
@ -35,7 +31,7 @@ class TestMixBase:
milvus_instance = get_milvus(args["handler"])
# milvus_instance.connect(uri=uri)
milvus_instance.create_collection({'collection_name': collection,
'dimension': dim,
'dimension': default_dim,
'index_file_size': index_file_size,
'metric_type': "L2"})
for i in range(10):
@ -88,7 +84,7 @@ class TestMixBase:
collection_name = gen_unique_str('test_mix_multi_collections')
param = {'collection_name': collection_name,
'dimension': dim,
'dimension': default_dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
@ -101,7 +97,7 @@ class TestMixBase:
collection_name = gen_unique_str('test_mix_multi_collections')
param = {'collection_name': collection_name,
'dimension': dim,
'dimension': default_dim,
'index_file_size': index_file_size,
'metric_type': MetricType.IP}
@ -6,26 +6,11 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *
dim = 128
segment_row_count = 5000
collection_id = "partition"
nprobe = 1
tag = "1970_01_01"
nb = 6000
tag = "partition_tag"
field_name = "float_vector"
entity = gen_entities(1)
entities = gen_entities(nb)
raw_vector, binary_entity = gen_binary_entities(1)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
class TestCreateBase:
The following cases are used to test `create_partition` function
@ -37,22 +22,24 @@ class TestCreateBase:
method: call function: create_partition
expected: status ok
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
# TODO: enable
def test_create_partition_limit(self, connect, collection, args):
target: test create partitions, check status returned
method: call function: create_partition for 4097 times
expected: exception raised
threads_num = 16
threads_num = 8
threads = []
if args["handler"] == "HTTP":
pytest.skip("skip in http mode")
def create(connect, threads_num):
for i in range(4096 // threads_num):
for i in range(max_partition_num // threads_num):
tag_tmp = gen_unique_str()
connect.create_partition(collection, tag_tmp)
@ -73,9 +60,9 @@ class TestCreateBase:
method: call function: create_partition
expected: status ok
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
def test_create_partition_collection_not_existed(self, connect):
@ -85,7 +72,7 @@ class TestCreateBase:
collection_name = gen_unique_str()
with pytest.raises(Exception) as e:
connect.create_partition(collection_name, tag)
connect.create_partition(collection_name, default_tag)
def test_create_partition_tag_name_None(self, connect, collection):
@ -103,11 +90,11 @@ class TestCreateBase:
method: call function: create_partition, and again
expected: status ok
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
tag_name = gen_unique_str()
connect.create_partition(collection, tag_name)
tag_list = connect.list_partitions(collection)
assert tag in tag_list
assert default_tag in tag_list
assert tag_name in tag_list
assert "_default" in tag_list
@ -117,9 +104,9 @@ class TestCreateBase:
method: call function: create_partition
expected: status ok
connect.create_partition(id_collection, tag)
ids = [i for i in range(nb)]
insert_ids = connect.insert(id_collection, entities, ids)
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.insert(id_collection, default_entities, ids)
assert len(insert_ids) == len(ids)
def test_create_partition_insert_with_tag(self, connect, id_collection):
@ -128,9 +115,9 @@ class TestCreateBase:
method: call function: create_partition
expected: status ok
connect.create_partition(id_collection, tag)
ids = [i for i in range(nb)]
insert_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
assert len(insert_ids) == len(ids)
def test_create_partition_insert_with_tag_not_existed(self, connect, collection):
@ -140,10 +127,10 @@ class TestCreateBase:
expected: status not ok
tag_new = "tag_new"
connect.create_partition(collection, tag)
ids = [i for i in range(nb)]
connect.create_partition(collection, default_tag)
ids = [i for i in range(default_nb)]
with pytest.raises(Exception) as e:
insert_ids = connect.insert(collection, entities, ids, partition_tag=tag_new)
insert_ids = connect.insert(collection, default_entities, ids, partition_tag=tag_new)
def test_create_partition_insert_same_tags(self, connect, id_collection):
@ -151,14 +138,14 @@ class TestCreateBase:
method: call function: create_partition
expected: status ok
connect.create_partition(id_collection, tag)
ids = [i for i in range(nb)]
insert_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
ids = [(i+nb) for i in range(nb)]
new_insert_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
connect.create_partition(id_collection, default_tag)
ids = [i for i in range(default_nb)]
insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
ids = [(i+default_nb) for i in range(default_nb)]
new_insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
res = connect.count_entities(id_collection)
assert res == nb * 2
assert res == default_nb * 2
def test_create_partition_insert_same_tags_two_collections(self, connect, collection):
@ -167,17 +154,17 @@ class TestCreateBase:
method: call function: create_partition
expected: status ok, collection length is correct
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
collection_new = gen_unique_str()
connect.create_collection(collection_new, default_fields)
connect.create_partition(collection_new, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
ids = connect.insert(collection_new, entities, partition_tag=tag)
connect.create_partition(collection_new, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
ids = connect.insert(collection_new, default_entities, partition_tag=default_tag)
connect.flush([collection, collection_new])
res = connect.count_entities(collection)
assert res == nb
assert res == default_nb
res = connect.count_entities(collection_new)
assert res == nb
assert res == default_nb
class TestShowBase:
@ -193,9 +180,9 @@ class TestShowBase:
method: create partition first, then call function: list_partitions
expected: status ok, partition correct
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
res = connect.list_partitions(collection)
assert tag in res
assert default_tag in res
def test_list_partitions_no_partition(self, connect, collection):
@ -213,10 +200,10 @@ class TestShowBase:
expected: status ok, partitions correct
tag_new = gen_unique_str()
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, tag_new)
res = connect.list_partitions(collection)
assert tag in res
assert default_tag in res
assert tag_new in res
@ -240,8 +227,8 @@ class TestHasBase:
method: create partition first, then call function: has_partition
expected: status ok, result true
connect.create_partition(collection, tag)
res = connect.has_partition(collection, tag)
connect.create_partition(collection, default_tag)
res = connect.has_partition(collection, default_tag)
assert res
@ -251,9 +238,9 @@ class TestHasBase:
method: create partition first, then call function: has_partition
expected: status ok, result true
for tag_name in [tag, "tag_new", "tag_new_new"]:
for tag_name in [default_tag, "tag_new", "tag_new_new"]:
connect.create_partition(collection, tag_name)
for tag_name in [tag, "tag_new", "tag_new_new"]:
for tag_name in [default_tag, "tag_new", "tag_new_new"]:
res = connect.has_partition(collection, tag_name)
assert res
@ -263,7 +250,7 @@ class TestHasBase:
method: then call function: has_partition, with tag not existed
expected: status ok, result empty
res = connect.has_partition(collection, tag)
res = connect.has_partition(collection, default_tag)
assert not res
@ -274,7 +261,7 @@ class TestHasBase:
expected: status not ok
with pytest.raises(Exception) as e:
res = connect.has_partition("not_existed_collection", tag)
res = connect.has_partition("not_existed_collection", default_tag)
def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
@ -284,13 +271,9 @@ class TestHasBase:
expected: status ok
tag_name = get_tag_name
connect.create_partition(collection, tag)
if isinstance(tag_name, str):
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
res = connect.has_partition(collection, tag_name)
assert not res
with pytest.raises(Exception) as e:
res = connect.has_partition(collection, tag_name)
class TestDropBase:
@ -306,11 +289,11 @@ class TestDropBase:
method: create partitions first, then call function: drop_partition
expected: status ok, no partitions in db
connect.create_partition(collection, tag)
connect.drop_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.drop_partition(collection, default_tag)
res = connect.list_partitions(collection)
tag_list = []
assert tag not in tag_list
assert default_tag not in tag_list
def test_drop_partition_tag_not_existed(self, connect, collection):
@ -318,7 +301,7 @@ class TestDropBase:
method: create partitions first, then call function: drop_partition
expected: status not ok
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
new_tag = "new_tag"
with pytest.raises(Exception) as e:
connect.drop_partition(collection, new_tag)
@ -329,10 +312,10 @@ class TestDropBase:
method: create partitions first, then call function: drop_partition
expected: status not ok
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
new_collection = gen_unique_str()
with pytest.raises(Exception) as e:
connect.drop_partition(new_collection, tag)
connect.drop_partition(new_collection, default_tag)
def test_drop_partition_repeatedly(self, connect, collection):
@ -341,13 +324,13 @@ class TestDropBase:
method: create partitions first, then call function: drop_partition
expected: status not ok, no partitions in db
connect.create_partition(collection, tag)
connect.drop_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.drop_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.drop_partition(collection, tag)
connect.drop_partition(collection, default_tag)
tag_list = connect.list_partitions(collection)
assert tag not in tag_list
assert default_tag not in tag_list
def test_drop_partition_create(self, connect, collection):
@ -355,12 +338,12 @@ class TestDropBase:
method: create partitions first, then call function: drop_partition, create_partition
expected: status not ok, partition in db
connect.create_partition(collection, tag)
connect.drop_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.drop_partition(collection, default_tag)
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
tag_list = connect.list_partitions(collection)
assert tag in tag_list
assert default_tag in tag_list
class TestNameInvalid(object):
@ -378,6 +361,7 @@ class TestNameInvalid(object):
def get_collection_name(self, request):
yield request.param
def test_drop_partition_with_invalid_collection_name(self, connect, collection, get_collection_name):
target: test drop partition, with invalid collection name, check status returned
@ -385,10 +369,11 @@ class TestNameInvalid(object):
expected: status not ok
collection_name = get_collection_name
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.drop_partition(collection_name, tag)
connect.drop_partition(collection_name, default_tag)
def test_drop_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
target: test drop partition, with invalid tag name, check status returned
@ -396,10 +381,11 @@ class TestNameInvalid(object):
expected: status not ok
tag_name = get_tag_name
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
connect.drop_partition(collection, tag_name)
def test_list_partitions_with_invalid_collection_name(self, connect, collection, get_collection_name):
target: test show partitions, with invalid collection name, check status returned
@ -407,6 +393,6 @@ class TestNameInvalid(object):
expected: status not ok
collection_name = get_collection_name
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
with pytest.raises(Exception) as e:
res = connect.list_partitions(collection_name)
@ -1,49 +0,0 @@
import time
import pdb
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
dim = 128
collection_id = "test_wal"
segment_row_count = 5000
tag = "1970_01_01"
insert_interval_time = 1.5
nb = 6000
field_name = "float_vector"
entity = gen_entities(1)
binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
class TestWalBase:
The following cases are used to test WAL functionality
def test_wal_server_crashed_recovery(self, connect, collection):
target: test wal when server crashed unexpectedly and restarted
method: add vectors, server killed before flush, restarted server and flush
expected: status ok, add request is recovered and vectors added
ids = connect.insert(collection, entity)
res = connect.count_entities(collection)
logging.getLogger().info(res) # should be 0 because no auto flush
logging.getLogger().info("Stop server and restart")
# kill server and restart. auto flush should be set to 15 seconds.
# time.sleep(15)
res = connect.count_entities(collection)
assert res == 1
res = connect.get_entity_by_id(collection, [ids[0]])
@ -13,16 +13,25 @@ from milvus import Milvus, DataType
port = 19530
epsilon = 0.000001
namespace = "milvus"
default_flush_interval = 1
big_flush_interval = 1000
dimension = 128
nb = 6000
top_k = 10
segment_row_count = 5000
default_drop_interval = 3
default_dim = 128
default_nb = 1200
default_top_k = 10
max_top_k = 16384
max_partition_num = 256
default_segment_row_limit = 1000
default_server_segment_row_limit = 1024 * 512
default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"
default_partition_name = "_default"
default_tag = "1970_01_01"
# TODO: disable RHNSW_SQ/PQ in 0.11.0
all_index_types = [
@ -32,6 +41,8 @@ all_index_types = [
# "NSG",
@ -45,6 +56,8 @@ default_index_params = [
{"M": 48, "efConstruction": 500},
# {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
{"n_trees": 50},
{"M": 48, "efConstruction": 500, "PQM": 64},
{"M": 48, "efConstruction": 500},
{"nlist": 128},
{"nlist": 128}
@ -66,6 +79,10 @@ def ivf():
return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
def skip_pq():
return ["IVF_PQ", "RHNSW_PQ", "RHNSW_SQ"]
def binary_metrics():
@ -123,6 +140,10 @@ def get_milvus(host, port, uri=None, handler=None, **kwargs):
return milvus
def reset_build_index_threshold(connect):
connect.set_config("engine", "build_index_threshold", 1024)
def disable_flush(connect):
connect.set_config("storage", "auto_flush_interval", big_flush_interval)
@ -204,14 +225,14 @@ def gen_single_filter_fields():
fields = []
for data_type in DataType:
if data_type in [DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]:
fields.append({"field": data_type.name, "type": data_type})
fields.append({"name": data_type.name, "type": data_type})
return fields
def gen_single_vector_fields():
fields = []
for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
field = {"field": data_type.name, "type": data_type, "params": {"dim": dimension}}
field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}}
return fields
@ -219,11 +240,11 @@ def gen_single_vector_fields():
def gen_default_fields(auto_id=True):
default_fields = {
"fields": [
{"field": "int64", "type": DataType.INT64},
{"field": "float", "type": DataType.FLOAT},
{"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": dimension}},
{"name": "int64", "type": DataType.INT64},
{"name": "float", "type": DataType.FLOAT},
{"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": default_dim}},
"segment_row_limit": segment_row_count,
"segment_row_limit": default_segment_row_limit,
"auto_id" : auto_id
return default_fields
@ -232,37 +253,37 @@ def gen_default_fields(auto_id=True):
def gen_binary_default_fields(auto_id=True):
default_fields = {
"fields": [
{"field": "int64", "type": DataType.INT64},
{"field": "float", "type": DataType.FLOAT},
{"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": dimension}}
{"name": "int64", "type": DataType.INT64},
{"name": "float", "type": DataType.FLOAT},
{"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}}
"segment_row_limit": segment_row_count,
"segment_row_limit": default_segment_row_limit,
"auto_id" : auto_id
return default_fields
def gen_entities(nb, is_normal=False):
vectors = gen_vectors(nb, dimension, is_normal)
vectors = gen_vectors(nb, default_dim, is_normal)
entities = [
{"field": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
{"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
{"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
{"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
{"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
return entities
def gen_binary_entities(nb):
raw_vectors, vectors = gen_binary_vectors(nb, dimension)
raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
entities = [
{"field": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
{"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
{"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": vectors}
{"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
{"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
{"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": vectors}
return raw_vectors, entities
def gen_entities_by_fields(fields, nb, dimension):
def gen_entities_by_fields(fields, nb, dim):
entities = []
for field in fields:
if field["type"] in [DataType.INT32, DataType.INT64]:
@ -270,9 +291,9 @@ def gen_entities_by_fields(fields, nb, dimension):
elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
field_value = [3.0 for i in range(nb)]
elif field["type"] == DataType.BINARY_VECTOR:
field_value = gen_binary_vectors(nb, dimension)[1]
field_value = gen_binary_vectors(nb, dim)[1]
elif field["type"] == DataType.FLOAT_VECTOR:
field_value = gen_vectors(nb, dimension)
field_value = gen_vectors(nb, dim)
field.update({"values": field_value})
return entities
@ -316,7 +337,7 @@ def gen_default_vector_expr(default_query):
def gen_default_term_expr(keyword="term", field="int64", values=None):
if values is None:
values = [i for i in range(nb // 2)]
values = [i for i in range(default_nb // 2)]
expr = {keyword: {field: {"values": values}}}
return expr
@ -330,7 +351,7 @@ def update_term_expr(src_term, terms):
def gen_default_range_expr(keyword="range", field="int64", ranges=None):
if ranges is None:
ranges = {"GT": 1, "LT": nb // 2}
ranges = {"GT": 1, "LT": default_nb // 2}
expr = {keyword: {field: ranges}}
return expr
@ -347,18 +368,18 @@ def gen_invalid_range():
{"range": 1},
{"range": {}},
{"range": []},
{"range": {"range": {"int64": {"GT": 0, "LT": nb // 2}}}}
{"range": {"range": {"int64": {"GT": 0, "LT": default_nb // 2}}}}
return range
def gen_valid_ranges():
ranges = [
{"GT": 0, "LT": nb//2},
{"GT": nb // 2, "LT": nb*2},
{"GT": 0, "LT": default_nb//2},
{"GT": default_nb // 2, "LT": default_nb * 2},
{"GT": 0},
{"LT": nb},
{"GT": -1, "LT": top_k},
{"LT": default_nb},
{"GT": -1, "LT": default_top_k},
return ranges
@ -368,7 +389,7 @@ def gen_invalid_term():
{"term": 1},
{"term": []},
{"term": {}},
{"term": {"term": {"int64": {"values": [i for i in range(nb // 2)]}}}}
{"term": {"term": {"int64": {"values": [i for i in range(default_nb // 2)]}}}}
return terms
@ -378,7 +399,7 @@ def add_field_default(default_fields, type=DataType.INT64, field_name=None):
if field_name is None:
field_name = gen_unique_str()
field = {
"field": field_name,
"name": field_name,
"type": type
@ -391,7 +412,7 @@ def add_field(entities, field_name=None):
if field_name is None:
field_name = gen_unique_str()
field = {
"field": field_name,
"name": field_name,
"type": DataType.INT64,
"values": [i for i in range(nb)]
@ -401,9 +422,9 @@ def add_field(entities, field_name=None):
def add_vector_field(entities, is_normal=False):
nb = len(entities[0]["values"])
vectors = gen_vectors(nb, dimension, is_normal)
vectors = gen_vectors(nb, default_dim, is_normal)
field = {
"field": gen_unique_str(),
"name": gen_unique_str(),
"type": DataType.FLOAT_VECTOR,
"values": vectors
@ -434,15 +455,15 @@ def remove_vector_field(entities):
def update_field_name(entities, old_name, new_name):
tmp_entities = copy.deepcopy(entities)
for item in tmp_entities:
if item["field"] == old_name:
item["field"] = new_name
if item["name"] == old_name:
item["name"] = new_name
return tmp_entities
def update_field_type(entities, old_name, new_name):
tmp_entities = copy.deepcopy(entities)
for item in tmp_entities:
if item["field"] == old_name:
if item["name"] == old_name:
item["type"] = new_name
return tmp_entities
@ -456,21 +477,20 @@ def update_field_value(entities, old_type, new_value):
return tmp_entities
def add_vector_field(nb, dimension=dimension):
def add_vector_field(nb, dimension=default_dim):
field_name = gen_unique_str()
field = {
"field": field_name,
"name": field_name,
"type": DataType.FLOAT_VECTOR,
"values": gen_vectors(nb, dimension)
return field_name
def gen_segment_row_counts():
def gen_segment_row_limits():
sizes = [
return sizes
@ -534,12 +554,8 @@ def gen_invalid_strs():
# "",
# None,
"12 s",
" siede ",
"a".join("a" for i in range(256))
@ -616,12 +632,7 @@ def gen_invalid_vectors():
" ",
" siede ",
"a".join("a" for i in range(256))
@ -639,7 +650,7 @@ def gen_invaild_search_params():
for nprobe in gen_invalid_params():
ivf_search_params = {"index_type": index_type, "search_params": {"nprobe": nprobe}}
elif index_type == "HNSW":
elif index_type in ["HNSW", "RHNSW_PQ", "RHNSW_SQ"]:
for ef in gen_invalid_params():
hnsw_search_param = {"index_type": index_type, "search_params": {"ef": ef}}
@ -667,9 +678,13 @@ def gen_invalid_index():
for M in gen_invalid_params():
index_param = {"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}
index_param = {"index_type": "RHNSW_PQ", "params": {"M": M, "efConstruction": 100}}
index_param = {"index_type": "RHNSW_SQ", "params": {"M": M, "efConstruction": 100}}
for efConstruction in gen_invalid_params():
index_param = {"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}
index_param = {"index_type": "RHNSW_PQ", "params": {"M": 16, "efConstruction": efConstruction}}
index_param = {"index_type": "RHNSW_SQ", "params": {"M": 16, "efConstruction": efConstruction}}
for search_length in gen_invalid_params():
index_param = {"index_type": "NSG",
@ -688,6 +703,8 @@ def gen_invalid_index():
index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
index_params.append({"index_type": "RHNSW_PQ", "params": {"invalid_key": 16, "efConstruction": 100}})
index_params.append({"index_type": "RHNSW_SQ", "params": {"invalid_key": 16, "efConstruction": 100}})
index_params.append({"index_type": "NSG",
"params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
"knng": 100}})
@ -720,7 +737,7 @@ def gen_index():
for nlist in nlists \
for m in pq_ms]
elif index_type == "HNSW":
elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}} \
for M in Ms \
for efConstruction in efConstructions]
@ -763,7 +780,7 @@ def get_search_param(index_type, metric_type="L2"):
search_params = {"metric_type": metric_type}
if index_type in ivf() or index_type in binary_support():
search_params.update({"nprobe": 64})
elif index_type == "HNSW":
elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
search_params.update({"ef": 64})
elif index_type == "NSG":
search_params.update({"search_length": 100})
@ -788,7 +805,6 @@ def restart_server(helm_release_name):
from kubernetes import client, config
namespace = "milvus"
# service_name = "%s.%s.svc.cluster.local" % (helm_release_name, namespace)
v1 = client.CoreV1Api()
@ -802,7 +818,7 @@ def restart_server(helm_release_name):
# v1.patch_namespaced_config_map(config_map_name, namespace, body, pretty='true')
# status_res = v1.read_namespaced_service_status(helm_release_name, namespace, pretty='true')
# print(status_res)
logging.getLogger().debug("Pod name: %s" % pod_name)
if pod_name is not None:
v1.delete_namespaced_pod(pod_name, namespace)
@ -811,24 +827,61 @@ def restart_server(helm_release_name):
logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
res = False
return res
logging.error("Sleep 10s after pod deleted")
# check if restart successfully
pods = v1.list_namespaced_pod(namespace)
for i in pods.items:
pod_name_tmp = i.metadata.name
if pod_name_tmp.find(helm_release_name) != -1:
if pod_name_tmp == pod_name:
elif pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1:
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
start_time = time.time()
while time.time() - start_time > timeout:
ready_break = False
while time.time() - start_time <= timeout:
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
if status_res.status.phase == "Running":
logging.error("Already running")
ready_break = True
if time.time() - start_time > timeout:
logging.error("Restart pod: %s timeout" % pod_name_tmp)
res = False
return res
if ready_break:
logging.error("Pod: %s not found" % helm_release_name)
res = False
raise Exception("Pod: %s not found" % pod_name)
follow = True
pretty = True
previous = True # bool | Return previous terminated container logs. Defaults to false. (optional)
since_seconds = 56 # int | A relative time in seconds before the current time from which to show logs. If this value precedes the time a pod was started, only logs since the pod start will be returned. If this value is in the future, no logs will be returned. Only one of sinceSeconds or sinceTime may be specified. (optional)
timestamps = True # bool | If true, add an RFC3339 or RFC3339Nano timestamp at the beginning of every line of log output. Defaults to false. (optional)
container = "milvus"
# start_time = time.time()
# while time.time() - start_time <= timeout:
# try:
# api_response = v1.read_namespaced_pod_log(pod_name_tmp, namespace, container=container, follow=follow,
# pretty=pretty, previous=previous, since_seconds=since_seconds,
# timestamps=timestamps)
# logging.error(api_response)
# return res
# except Exception as e:
# logging.error("Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
# # waiting for server start
# time.sleep(5)
# # res = False
# # return res
# if time.time() - start_time > timeout:
# logging.error("Restart pod: %s timeout" % pod_name_tmp)
# res = False
return res
Reference in New Issue