mirror of https://github.com/milvus-io/milvus.git
cp test from 0.11.0
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/3980/head^2
parent ecb01e0c49
commit fff5da0de4
@@ -10,59 +10,59 @@
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.10" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.68" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.8.0-SNAPSHOT" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.27.2" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.1-android" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.27.2" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.11" level="project" />
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
<orderEntry type="library" name="Maven: org.testng:testng:6.14.3" level="project" />
<orderEntry type="library" name="Maven: com.beust:jcommander:1.72" level="project" />
<orderEntry type="library" name="Maven: org.apache-extras.beanshell:bsh:2.0b6" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.68" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.73" level="project" />
<orderEntry type="library" name="Maven: junit:junit:4.13" level="project" />
<orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.8.0-SNAPSHOT" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven.plugins:maven-gpg-plugin:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-api:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-project:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-settings:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-profile:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-artifact-manager:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven.wagon:wagon-provider-api:1.0-beta-6" level="project" />
<orderEntry type="library" name="Maven: backport-util-concurrent:backport-util-concurrent:3.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-registry:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-interpolation:1.11" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-container-default:1.0-alpha-9-stable-1" level="project" />
<orderEntry type="library" name="Maven: classworlds:classworlds:1.1-alpha-2" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-artifact:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-repository-metadata:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-model:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-utils:3.0.20" level="project" />
<orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-sec-dispatcher:1.4" level="project" />
<orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-cipher:1.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.android:annotations:4.1.1.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.perfmark:perfmark-api:0.19.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.9.0-SNAPSHOT" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.30.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.30.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.30.2" level="project" />
<orderEntry type="library" name="Maven: com.google.code.findbugs:jsr305:3.0.2" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.mojo:animal-sniffer-annotations:1.18" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.1-android" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.12.0" level="project" />
<orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.17.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.30.2" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.2-android" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:failureaccess:1.0.1" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:listenablefuture:9999.0-empty-to-avoid-conflict-with-guava" level="project" />
<orderEntry type="library" name="Maven: org.checkerframework:checker-compat-qual:2.5.5" level="project" />
<orderEntry type="library" name="Maven: com.google.j2objc:j2objc-annotations:1.3" level="project" />
<orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.17.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.27.2" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java-util:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" name="Maven: com.google.errorprone:error_prone_annotations:2.3.4" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: org.codehaus.mojo:animal-sniffer-annotations:1.18" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.30.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.30.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.30.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.android:annotations:4.1.1.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.perfmark:perfmark-api:0.19.0" level="project" />
<orderEntry type="library" name="Maven: org.json:json:20190722" level="project" />
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.30" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.12.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.12.1" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.12.1" level="project" />
<orderEntry type="library" name="Maven: org.testcontainers:testcontainers:1.14.3" level="project" />
<orderEntry type="library" name="Maven: org.jetbrains:annotations:19.0.0" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.20" level="project" />
<orderEntry type="library" name="Maven: org.rnorth.duct-tape:duct-tape:1.0.8" level="project" />
<orderEntry type="library" name="Maven: org.rnorth.visible-assertions:visible-assertions:2.1.2" level="project" />
<orderEntry type="library" name="Maven: org.rnorth:tcp-unix-socket-proxy:1.0.2" level="project" />
<orderEntry type="library" name="Maven: com.kohlschutter.junixsocket:junixsocket-native-common:2.0.4" level="project" />
<orderEntry type="library" name="Maven: org.scijava:native-lib-loader:2.0.2" level="project" />
<orderEntry type="library" name="Maven: com.kohlschutter.junixsocket:junixsocket-common:2.0.4" level="project" />
<orderEntry type="library" name="Maven: net.java.dev.jna:jna-platform:5.5.0" level="project" />
<orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.5.0" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.slf4j:slf4j-simple:1.7.30" level="project" />
</component>
</module>
@@ -91,7 +91,7 @@
<dependency>
  <groupId>org.testng</groupId>
  <artifactId>testng</artifactId>
  <version>7.3.0</version>
  <version>6.14.3</version>
</dependency>

<dependency>
@@ -118,6 +118,17 @@
  <artifactId>milvus-sdk-java</artifactId>
  <version>0.9.0-SNAPSHOT</version>
</dependency>
<dependency>
  <groupId>org.testcontainers</groupId>
  <artifactId>testcontainers</artifactId>
  <version>1.14.3</version>
</dependency>
<dependency>
  <groupId>org.slf4j</groupId>
  <artifactId>slf4j-simple</artifactId>
  <version>1.7.30</version>
  <scope>test</scope>
</dependency>

<!-- <dependency>-->
<!--   <groupId>io.grpc</groupId>-->
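Note: the newly added test dependencies above (testcontainers 1.14.3 and a test-scoped slf4j-simple) suggest the suite can be pointed at a containerized Milvus server. A minimal sketch of that idea; the image tag is a placeholder assumption, and this commit only adds the dependency without wiring a container into the suite:

    import org.testcontainers.containers.GenericContainer;

    public class MilvusContainerSketch {
        public static void main(String[] args) {
            // "milvusdb/milvus:latest" is a hypothetical tag, not something this commit pins.
            try (GenericContainer<?> milvus =
                         new GenericContainer<>("milvusdb/milvus:latest").withExposedPorts(19530)) {
                milvus.start();
                // Point the suite at the container before running TestNG.
                MainClass.setHost(milvus.getContainerIpAddress()); // getHost() in newer releases
                MainClass.setPort(milvus.getMappedPort(19530));
            }
        }
    }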
@@ -3,6 +3,7 @@ package com;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import io.milvus.client.*;

public final class Constants {

@@ -20,37 +21,39 @@ public final class Constants {

    public static final double epsilon = 0.001;

    public static final int segmentRowLimit = 5000;

    public static final String fieldNameKey = "name";

    public static final String vectorType = "float";

    public static final String defaultMetricType = "L2";
    public static final String binaryVectorType = "binary";

    public static final String indexType = "IVF_SQ8";
    public static final MetricType defaultMetricType = MetricType.L2;

    public static final String defaultIndexType = "FLAT";
    public static final IndexType indexType = IndexType.IVF_SQ8;

    public static final String defaultBinaryIndexType = "BIN_FLAT";
    public static final IndexType defaultIndexType = IndexType.FLAT;

    public static final String defaultBinaryMetricType = "JACCARD";
    public static final IndexType defaultBinaryIndexType = IndexType.BIN_IVF_FLAT;

    public static final String floatFieldName = "float_vector";
    public static final MetricType defaultBinaryMetricType = MetricType.JACCARD;

    public static final String binaryFieldName = "binary_vector";
    public static final String intFieldName = "int64";

    public static final String indexParam = Utils.setIndexParam(indexType, "L2", n_list);
    public static final String floatFieldName = "float";

    public static final String binaryIndexParam = Utils.setIndexParam(defaultBinaryIndexType, defaultBinaryMetricType, n_list);
    public static final String floatVectorFieldName = "float_vector";

    public static final String binaryVectorFieldName = "binary_vector";

    public static final List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);

    public static final List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);

    public static final List<Map<String,Object>> defaultFields = Utils.genDefaultFields(dimension,false);
    public static final Map<String, List> defaultEntities = Utils.genDefaultEntities(nb, vectors);

    public static final List<Map<String,Object>> defaultBinaryFields = Utils.genDefaultFields(dimension,true);

    public static final List<Map<String,Object>> defaultEntities = Utils.genDefaultEntities(dimension, nb, vectors);

    public static final List<Map<String,Object>> defaultBinaryEntities = Utils.genDefaultBinaryEntities(dimension, nb, vectorsBinary);
    public static final Map<String, List> defaultBinaryEntities = Utils.genDefaultBinaryEntities(nb, vectorsBinary);

    public static final String searchParam = Utils.setSearchParam(defaultMetricType, vectors.subList(0, nq), topk, n_probe);
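The constants above move from bare strings ("IVF_SQ8", "L2", "JACCARD") to the SDK's typed IndexType and MetricType enums. A small sketch of how the typed constants get consumed, using the Index builder that appears later in this commit; the collection name and the client variable are placeholders:

    Index index = Index
            .create("some_collection", Constants.floatVectorFieldName)
            .setIndexType(Constants.indexType)           // IndexType.IVF_SQ8
            .setMetricType(Constants.defaultMetricType)  // MetricType.L2
            .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
    client.createIndex(index);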
@@ -3,11 +3,7 @@ package com;
import io.milvus.client.*;
import org.apache.commons.cli.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.SkipException;
import org.testng.TestNG;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.AfterSuite;
import org.testng.annotations.AfterTest;
import org.testng.annotations.DataProvider;
import org.testng.xml.XmlClass;
import org.testng.xml.XmlSuite;

@@ -15,16 +11,17 @@ import org.testng.xml.XmlTest;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class MainClass {
    private static String HOST = "127.0.0.1";
    // private static String HOST = "192.168.1.238";
    private static int PORT = 19530;
    private int segmentRowCount = 5000;
    private static ConnectParam CONNECT_PARAM = new ConnectParam.Builder()
            .withHost(HOST)
            .withPort(PORT)
            .build();
    private static MilvusClient client;

    public static void setHost(String host) {
        MainClass.HOST = host;

@@ -34,98 +31,85 @@ public class MainClass {
        MainClass.PORT = port;
    }

    public static String getHost() {
        return MainClass.HOST;
    }

    public static int getPort() {
        return MainClass.PORT;
    }

    @DataProvider(name="DefaultConnectArgs")
    public static Object[][] defaultConnectArgs(){
        return new Object[][]{{HOST, PORT}};
    }

    @DataProvider(name="ConnectInstance")
    public Object[][] connectInstance() throws ConnectFailedException {
        MilvusClient client = new MilvusGrpcClient();
    public Object[][] connectInstance() throws Exception {
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(HOST)
                .withPort(PORT)
                .build();
        client.connect(connectParam);
        client = new MilvusGrpcClient(connectParam).withLogging();
        String collectionName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, collectionName}};
    }

    @DataProvider(name="DisConnectInstance")
    public Object[][] disConnectInstance() throws ConnectFailedException {
    public Object[][] disConnectInstance(){
        // Generate connection instance
        MilvusClient client = new MilvusGrpcClient();
        client.connect(CONNECT_PARAM);
        try {
            client.disconnect();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        client = new MilvusGrpcClient(CONNECT_PARAM).withLogging();
        client.close();
        String collectionName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, collectionName}};
    }

    private Object[][] genCollection(boolean isBinary, boolean autoId) throws ConnectFailedException {
    private Object[][] genCollection(boolean isBinary, boolean autoId) throws Exception {
        Object[][] collection;
        String collectionName = Utils.genUniqueStr("collection");
        List<Map<String, Object>> defaultFields = Utils.genDefaultFields(Constants.dimension,isBinary);
        String jsonParams = String.format("{\"segment_row_count\": %s, \"auto_id\": %s}",segmentRowCount, autoId);
        // Generate connection instance
        MilvusClient client = new MilvusGrpcClient();
        client.connect(CONNECT_PARAM);
        CollectionMapping cm = new CollectionMapping.Builder(collectionName)
                .withFields(defaultFields)
                .withParamsInJson(jsonParams)
                .build();
        Response res = client.createCollection(cm);
        if (!res.ok()) {
            System.out.println(res.getMessage());
            throw new SkipException("Collection created failed");
        client = new MilvusGrpcClient(CONNECT_PARAM).withLogging();
        CollectionMapping cm = CollectionMapping
                .create(collectionName)
                .addField(Constants.intFieldName, DataType.INT64)
                .addField(Constants.floatFieldName, DataType.FLOAT)
                .setParamsInJson(new JsonBuilder()
                        .param("segment_row_limit", segmentRowCount)
                        .param("auto_id", autoId)
                        .build());
        if (isBinary) {
            cm.addVectorField("binary_vector", DataType.VECTOR_BINARY, Constants.dimension);
        } else {
            cm.addVectorField("float_vector", DataType.VECTOR_FLOAT, Constants.dimension);
        }
        client.createCollection(cm);
        collection = new Object[][]{{client, collectionName}};
        return collection;
    }

    @DataProvider(name="Collection")
    public Object[][] provideCollection() throws ConnectFailedException, InterruptedException {
    public Object[][] provideCollection() throws Exception, InterruptedException {
        Object[][] collection = genCollection(false,true);
        return collection;
        // List<String> tableNames = client.listCollections().getCollectionNames();
        // for (int j = 0; j < tableNames.size(); ++j
        // ) {
        //     client.dropCollection(tableNames.get(j));
        // }
        // Thread.currentThread().sleep(2000);
    }

    @DataProvider(name="IdCollection")
    public Object[][] provideIdCollection() throws ConnectFailedException, InterruptedException {
    public Object[][] provideIdCollection() throws Exception, InterruptedException {
        Object[][] idCollection = genCollection(false,false);
        return idCollection;
    }

    @DataProvider(name="BinaryCollection")
    public Object[][] provideBinaryCollection() throws ConnectFailedException, InterruptedException {
    public Object[][] provideBinaryCollection() throws Exception, InterruptedException {
        Object[][] binaryCollection = genCollection(true,true);
        return binaryCollection;
    }

    @DataProvider(name="BinaryIdCollection")
    public Object[][] provideBinaryIdCollection() throws ConnectFailedException, InterruptedException {
    public Object[][] provideBinaryIdCollection() throws Exception, InterruptedException {
        Object[][] binaryIdCollection = genCollection(true,false);
        return binaryIdCollection;
    }

    @AfterSuite
    public void dropCollection(){
        // MilvusClient client = new MilvusGrpcClient();
        // List<String> collectionNames = client.listCollections().getCollectionNames();
        // collectionNames.forEach(client::dropCollection);
        System.out.println("after suite");
    }

    @AfterMethod
    public void after(){
        System.out.println("after method");
    }

    public static void main(String[] args) {
        CommandLineParser parser = new DefaultParser();

@@ -156,19 +140,20 @@ public class MainClass {
        List<XmlClass> classes = new ArrayList<XmlClass>();

        // classes.add(new XmlClass("com.TestPing"));
        // classes.add(new XmlClass("com.TestAddVectors"));
        classes.add(new XmlClass("com.TestConnect"));
        // classes.add(new XmlClass("com.TestDeleteVectors"));
        // classes.add(new XmlClass("com.TestInsertEntities"));
        // classes.add(new XmlClass("com.TestConnect"));
        // classes.add(new XmlClass("com.TestDeleteEntities"));
        // classes.add(new XmlClass("com.TestIndex"));
        // classes.add(new XmlClass("com.TestCompact"));
        // classes.add(new XmlClass("com.TestSearchVectors"));
        classes.add(new XmlClass("com.TestCollection"));
        // classes.add(new XmlClass("com.TestSearchEntities"));
        // classes.add(new XmlClass("com.TestCollection"));
        // classes.add(new XmlClass("com.TestCollectionCount"));
        // classes.add(new XmlClass("com.TestFlush"));
        // classes.add(new XmlClass("com.TestPartition"));
        // classes.add(new XmlClass("com.TestGetVectorByID"));
        // classes.add(new XmlClass("com.TestGetEntityByID"));
        // classes.add(new XmlClass("com.TestCollectionInfo"));
        // classes.add(new XmlClass("com.TestSearchByIds"));
        classes.add(new XmlClass("com.TestBeforeAndAfter"));

        test.setXmlClasses(classes);
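The data providers above carry the central SDK 0.9.0 shift in this commit: a client is no longer connected after construction, but built from its ConnectParam and torn down with close() instead of disconnect(). A condensed before/after sketch of just that lifecycle, extracted from the providers:

    // 0.8.0 style (removed): connect/disconnect on a default client
    MilvusClient old = new MilvusGrpcClient();
    old.connect(new ConnectParam.Builder().withHost("127.0.0.1").withPort(19530).build());
    old.disconnect(); // declared to throw InterruptedException

    // 0.9.0 style (added): parameters at construction, simple teardown
    ConnectParam param = new ConnectParam.Builder().withHost("127.0.0.1").withPort(19530).build();
    MilvusClient client = new MilvusGrpcClient(param).withLogging();
    client.close();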
@@ -1,175 +1,151 @@
package com;

import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.*;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class TestCollection {
    int segmentRowCount = 5000;
    int dimension = 128;

    @BeforeClass
    public MilvusClient setUp() throws ConnectFailedException {
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost("192.168.1.6")
                .withPort(19530)
                .build();
        client.connect(connectParam);
        return client;
    }

    @AfterClass
    public void tearDown() throws ConnectFailedException {
        MilvusClient client = setUp();
        List<String> collectionNames = client.listCollections().getCollectionNames();
        // collectionNames.forEach(collection -> {client.dropCollection(collection);});
        for(String collection: collectionNames){
            System.out.print(collection+" ");
            client.dropCollection(collection);
        }
        System.out.println("After Test");
    }

    String intFieldName = Constants.intFieldName;
    String floatFieldName = Constants.floatFieldName;
    String floatVectorFieldName = Constants.floatVectorFieldName;
    Boolean autoId = true;

    // case-01
    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void testCreateCollection(MilvusClient client, String collectionName){
        CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
                .withFields(Utils.genDefaultFields(dimension,false))
                .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
                .build();
        Response res = client.createCollection(collectionSchema);
        assert(res.ok());
        Assert.assertEquals(res.ok(), true);
    public void testCreateCollection(MilvusClient client, String collectionName) {
        // Generate connection instance
        CollectionMapping cm = Utils.genCreateCollectionMapping(collectionName, true,false);
        client.createCollection(cm);
        Assert.assertEquals(client.hasCollection(collectionName), true);
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void testCreateCollectionDisconnect(MilvusClient client, String collectionName){
        CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
                .withFields(Utils.genDefaultFields(dimension,false))
                .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
                .build();
        Response res = client.createCollection(collectionSchema);
        assert(!res.ok());
    // case-02
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
    public void testCreateCollectionDisconnect(MilvusClient client, String collectionName) {
        CollectionMapping cm = Utils.genCreateCollectionMapping(collectionName, true, false);
        client.createCollection(cm);
    }

    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void testCreateCollectionRepeatably(MilvusClient client, String collectionName){
        CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
                .withFields(Utils.genDefaultFields(dimension,false))
                .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
                .build();
        Response res = client.createCollection(collectionSchema);
        Assert.assertEquals(res.ok(), true);
        Response resNew = client.createCollection(collectionSchema);
        Assert.assertEquals(resNew.ok(), false);
    // case-03
    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testCreateCollectionRepeatably(MilvusClient client, String collectionName) {
        CollectionMapping cm = Utils.genCreateCollectionMapping(collectionName, true, false);
        client.createCollection(cm);
        Assert.assertEquals(client.hasCollection(collectionName), true);
        client.createCollection(cm);
    }

    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void testCreateCollectionWrongParams(MilvusClient client, String collectionName){
    // case-04
    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testCreateCollectionWrongParams(MilvusClient client, String collectionName) {
        Integer dim = 0;
        CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionName)
                .withFields(Utils.genDefaultFields(dim,false))
                .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
                .build();
        Response res = client.createCollection(collectionSchema);
        System.out.println(res.toString());
        Assert.assertEquals(res.ok(), false);
        CollectionMapping cm = CollectionMapping.create(collectionName)
                .addField(intFieldName, DataType.INT64)
                .addField(floatFieldName, DataType.FLOAT)
                .addVectorField(floatVectorFieldName, DataType.VECTOR_FLOAT, dim)
                .setParamsInJson(new JsonBuilder()
                        .param("segment_row_limit", Constants.segmentRowLimit)
                        .param("auto_id", autoId)
                        .build());
        client.createCollection(cm);
    }

    // case-05
    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void testShowCollections(MilvusClient client, String collectionName){
    public void testShowCollections(MilvusClient client, String collectionName) {
        Integer collectionNum = 10;
        ListCollectionsResponse res = null;
        List<String> originCollections = new ArrayList<>();
        for (int i = 0; i < collectionNum; ++i) {
            String collectionNameNew = collectionName+"_"+Integer.toString(i);
            CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionNameNew)
                    .withFields(Utils.genDefaultFields(dimension,false))
                    .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
                    .build();
            client.createCollection(collectionSchema);
            List<String> collectionNames = client.listCollections().getCollectionNames();
            Assert.assertTrue(collectionNames.contains(collectionNameNew));
            originCollections.add(collectionNameNew);
            CollectionMapping cm = Utils.genCreateCollectionMapping(collectionNameNew, true, false);
            client.createCollection(cm);
        }
        List<String> listCollections = client.listCollections();
        originCollections.stream().forEach(listCollections::contains);
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void testShowCollectionsWithoutConnect(MilvusClient client, String collectionName){
        ListCollectionsResponse res = client.listCollections();
        assert(!res.getResponse().ok());
    // case-06
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
    public void testShowCollectionsWithoutConnect(MilvusClient client, String collectionName) {
        client.listCollections();
    }

    // case-07
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testDropCollection(MilvusClient client, String collectionName) throws InterruptedException {
        Response res = client.dropCollection(collectionName);
        assert(res.ok());
        Thread.currentThread().sleep(1000);
        List<String> collectionNames = client.listCollections().getCollectionNames();
    public void testDropCollection(MilvusClient client, String collectionName) {
        client.dropCollection(collectionName);
        Assert.assertEquals(client.hasCollection(collectionName), false);
        // Thread.currentThread().sleep(1000);
        List<String> collectionNames = client.listCollections();
        Assert.assertFalse(collectionNames.contains(collectionName));
    }

    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    // case-08
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testDropCollectionNotExisted(MilvusClient client, String collectionName) {
        Response res = client.dropCollection(collectionName+"_");
        assert(!res.ok());
        List<String> collectionNames = client.listCollections().getCollectionNames();
        client.dropCollection(collectionName+"_");
        List<String> collectionNames = client.listCollections();
        Assert.assertTrue(collectionNames.contains(collectionName));
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    // case-09
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
    public void testDropCollectionWithoutConnect(MilvusClient client, String collectionName) {
        Response res = client.dropCollection(collectionName);
        assert(!res.ok());
        client.dropCollection(collectionName);
    }

    // case-10
    // TODO
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testDescribeCollection(MilvusClient client, String collectionName) {
        GetCollectionInfoResponse res = client.getCollectionInfo(collectionName);
        assert(res.getResponse().ok());
        CollectionMapping collectionSchema = res.getCollectionMapping().get();
        List<Map<String,Object>> fields = (List<Map<String, Object>>) collectionSchema.getFields();
        CollectionMapping info = client.getCollectionInfo(collectionName);
        List<Map<String,Object>> fields = info.getFields();
        System.out.println(fields);
        int dim = 0;
        for(Map<String,Object> field: fields){
            if ("float_vector".equals(field.get("field"))) {
            if (floatVectorFieldName.equals(field.get(Constants.fieldNameKey))) {
                JSONObject jsonObject = JSONObject.parseObject(field.get("params").toString());
                String dimParams = jsonObject.getString("params");
                dim = Utils.getParam(dimParams,"dim");
                dim = jsonObject.getIntValue("dim");
            }
            continue;
        }
        String segmentParams = collectionSchema.getParamsInJson();
        Assert.assertEquals(dim, dimension);
        Assert.assertEquals(collectionSchema.getCollectionName(), collectionName);
        Assert.assertEquals(Utils.getParam(segmentParams,"segment_row_count"), segmentRowCount);
        JSONObject params = JSONObject.parseObject(info.getParamsInJson().toString());
        Assert.assertEquals(dim, Constants.dimension);
        Assert.assertEquals(info.getCollectionName(), collectionName);
        Assert.assertEquals(params.getIntValue("segment_row_limit"), Constants.segmentRowLimit);
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    // case-11
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
    public void testDescribeCollectionWithoutConnect(MilvusClient client, String collectionName) {
        GetCollectionInfoResponse res = client.getCollectionInfo(collectionName);
        assert(!res.getResponse().ok());
        client.getCollectionInfo(collectionName);
    }

    // case-12
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testHasCollectionNotExisted(MilvusClient client, String collectionName) {
        HasCollectionResponse res = client.hasCollection(collectionName+"_");
        assert(res.getResponse().ok());
        Assert.assertFalse(res.hasCollection());
        String collectionNameNew = collectionName+"_";
        Assert.assertFalse(client.hasCollection(collectionNameNew));
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    // case-13
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
    public void testHasCollectionWithoutConnect(MilvusClient client, String collectionName) {
        HasCollectionResponse res = client.hasCollection(collectionName);
        assert(!res.getResponse().ok());
        client.hasCollection(collectionName);
    }

    // case-14
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testHasCollection(MilvusClient client, String collectionName) {
        HasCollectionResponse res = client.hasCollection(collectionName);
        assert(res.getResponse().ok());
        Assert.assertTrue(res.hasCollection());
        Assert.assertTrue(client.hasCollection(collectionName));
    }
}
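Every test above follows the same migration: Response/assert(res.ok()) checks become TestNG expectedExceptions, with ClientSideMilvusException for calls made without a live connection and ServerSideMilvusException for requests the server rejects. A distilled sketch of the pattern; the client field is assumed to come from a data provider as in the suite:

    // Old: inspect the Response object
    // Response res = client.dropCollection(missingName);
    // assert(!res.ok());

    // New: the failure surfaces as an exception, declared on the test itself
    @Test(expectedExceptions = ServerSideMilvusException.class)
    public void dropMissingCollection() {
        client.dropCollection("no_such_collection"); // server rejects -> throws
    }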
@@ -1,87 +1,72 @@
package com;

import io.milvus.client.*;
import io.milvus.client.exception.ClientSideMilvusException;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.List;

public class TestCollectionCount {
    int segmentRowCount = 5000;
    int dimension = Constants.dimension;
    int nb = Constants.nb;

    // case-01
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testCollectionCountNoVectors(MilvusClient client, String collectionName) {
        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
        Assert.assertEquals(client.countEntities(collectionName), 0);
    }

    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    // case-02
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testCollectionCountCollectionNotExisted(MilvusClient client, String collectionName) {
        CountEntitiesResponse res = client.countEntities(collectionName+"_");
        assert(!res.getResponse().ok());
        client.countEntities(collectionName+"_");
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    // case-03
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
    public void testCollectionCountWithoutConnect(MilvusClient client, String collectionName) {
        CountEntitiesResponse res = client.countEntities(collectionName+"_");
        assert(!res.getResponse().ok());
        client.countEntities(collectionName);
    }

    // case-04
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testCollectionCount(MilvusClient client, String collectionName) throws InterruptedException {
        InsertParam insertParam =
                new InsertParam.Builder(collectionName)
                        .withFields(Constants.defaultEntities)
                        .build();
        InsertResponse insertResponse = client.insert(insertParam);
        // Insert returns a list of entity ids that you will be using (if you did not supply them yourself) to reference the entities you just inserted
        List<Long> vectorIds = insertResponse.getEntityIds();
        // Add vectors
        Response flushResponse = client.flush(collectionName);
        Assert.assertTrue(flushResponse.ok());
        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
    }

    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
    public void testCollectionCountBinary(MilvusClient client, String collectionName) throws InterruptedException {
        // Add vectors
        InsertParam insertParam = new InsertParam.Builder(collectionName)
                .withFields(Constants.defaultBinaryEntities)
                .build();
        client.insert(insertParam);
    public void testCollectionCount(MilvusClient client, String collectionName) {
        InsertParam insertParam = Utils.genInsertParam(collectionName);
        List<Long> ids = client.insert(insertParam);
        client.flush(collectionName);
        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
        // Insert returns a list of entity ids that you will be using (if you did not supply them yourself) to reference the entities you just inserted
        Assert.assertEquals(client.countEntities(collectionName), nb);
    }

    // case-05
    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
    public void testCollectionCountBinary(MilvusClient client, String collectionName) {
        // Add vectors
        InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
        List<Long> ids = client.insert(insertParam);
        client.flush(collectionName);
        Assert.assertEquals(client.countEntities(collectionName), nb);
    }

    // case-06
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testCollectionCountMultiCollections(MilvusClient client, String collectionName) throws InterruptedException {
    public void testCollectionCountMultiCollections(MilvusClient client, String collectionName) {
        Integer collectionNum = 10;
        CountEntitiesResponse res;
        for (int i = 0; i < collectionNum; ++i) {
            String collectionNameNew = collectionName + "_" + i;
            CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionNameNew)
                    .withFields(Utils.genDefaultFields(dimension,false))
                    .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
                    .build();
            Response createRes = client.createCollection(collectionSchema);
            Assert.assertEquals(createRes.ok(), true);
            CollectionMapping cm = Utils.genCreateCollectionMapping(collectionNameNew, true, false);
            client.createCollection(cm);
            // Add vectors
            InsertParam insertParam = new InsertParam.Builder(collectionNameNew)
                    .withFields(Constants.defaultEntities)
                    .build();
            InsertResponse insertRes = client.insert(insertParam);
            Assert.assertEquals(insertRes.ok(), true);
            Response flushRes = client.flush(collectionNameNew);
            Assert.assertEquals(flushRes.ok(), true);
            InsertParam insertParam = Utils.genInsertParam(collectionNameNew);
            List<Long> ids = client.insert(insertParam);
            client.flush(collectionNameNew);
        }
        for (int i = 0; i < collectionNum; ++i) {
            String collectionNameNew = collectionName + "_" + i;
            res = client.countEntities(collectionNameNew);
            Assert.assertEquals(res.getCollectionEntityCount(), nb);
            Assert.assertEquals(client.countEntities(collectionNameNew), nb);
        }
    }
}
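The count tests now lean on insert() returning the entity ids directly and countEntities() returning a plain count. The core round trip, as exercised above (Utils.genInsertParam is a helper of this test suite, not SDK API):

    InsertParam insertParam = Utils.genInsertParam(collectionName);
    List<Long> ids = client.insert(insertParam);   // ids come back directly
    client.flush(collectionName);                  // flush so the insert is visible to count
    Assert.assertEquals(client.countEntities(collectionName), Constants.nb);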
@@ -1,68 +0,0 @@
package com;

import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.annotations.Test;

import java.util.Collections;
import java.util.List;

public class TestCollectionInfo {
    int nb = Constants.nb;

    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testGetEntityIdsAfterDeleteEntities(MilvusClient client, String collectionName) {
        InsertParam insertParam = new InsertParam.Builder(collectionName)
                .withFields(Constants.defaultEntities)
                .build();
        InsertResponse resInsert = client.insert(insertParam);
        client.flush(collectionName);
        List<Long> idsBefore = resInsert.getEntityIds();
        client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
        client.flush(collectionName);
        Response res = client.getCollectionStats(collectionName);
        System.out.println(res.getMessage());
        JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
        int rowCount = collectionInfo.getIntValue("row_count");
        assert(rowCount == nb-1);
    }

    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testGetEntityIdsAterDeleteEntitiesIndexed(MilvusClient client, String collectionName) {
        InsertParam insertParam = new InsertParam.Builder(collectionName)
                .withFields(Constants.defaultEntities)
                .build();
        InsertResponse resInsert = client.insert(insertParam);
        client.flush(collectionName);
        Index index = new Index.Builder(collectionName, "float_vector")
                .withParamsInJson(Constants.indexParam).build();
        Response createIndexResponse = client.createIndex(index);
        assert(createIndexResponse.ok());
        List<Long> idsBefore = resInsert.getEntityIds();
        client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
        client.flush(collectionName);
        Response res = client.getCollectionStats(collectionName);
        System.out.println(res.getMessage());
        JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
        int rowCount = collectionInfo.getIntValue("row_count");
        assert(rowCount == nb-1);
    }

    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
    public void testGetEntityIdsAfterDeleteEntitiesBinary(MilvusClient client, String collectionName) {
        InsertParam insertParam = new InsertParam.Builder(collectionName)
                .withFields(Constants.defaultBinaryEntities)
                .build();
        InsertResponse resInsert = client.insert(insertParam);
        client.flush(collectionName);
        List<Long> idsBefore = resInsert.getEntityIds();
        client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
        client.flush(collectionName);
        Response res = client.getCollectionStats(collectionName);
        System.out.println(res.getMessage());
        JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
        int rowCount = collectionInfo.getIntValue("row_count");
        assert(rowCount == nb-1);
    }
}
@@ -0,0 +1,62 @@
package com;

import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.annotations.Test;

import java.util.Collections;
import java.util.List;

public class TestCollectionStats {
    int nb = Constants.nb;

    // case-01
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testGetEntityIdsAfterDeleteEntities(MilvusClient client, String collectionName) {
        InsertParam insertParam = Utils.genInsertParam(collectionName);
        List<Long> idsBefore = client.insert(insertParam);
        client.flush(collectionName);
        client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
        client.flush(collectionName);
        String stats = client.getCollectionStats(collectionName);
        JSONObject collectionStats = JSONObject.parseObject(stats);
        int rowCount = collectionStats.getIntValue("row_count");
        assert(rowCount == nb - 1);
    }

    // case-02
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testGetEntityIdsAfterDeleteEntitiesIndexed(MilvusClient client, String collectionName) {
        InsertParam insertParam = Utils.genInsertParam(collectionName);
        List<Long> idsBefore = client.insert(insertParam);
        client.flush(collectionName);
        Index index = Index
                .create(collectionName, Constants.floatVectorFieldName)
                .setIndexType(IndexType.IVF_SQ8)
                .setMetricType(MetricType.L2)
                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
        client.createIndex(index);
        client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
        client.flush(collectionName);
        String stats = client.getCollectionStats(collectionName);
        JSONObject collectionStats = JSONObject.parseObject(stats);
        int rowCount = collectionStats.getIntValue("row_count");
        assert(rowCount == nb - 1);
    }

    // case-03
    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
    public void testGetEntityIdsAfterDeleteEntitiesBinary(MilvusClient client, String collectionName) {
        InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
        List<Long> idsBefore = client.insert(insertParam);
        client.flush(collectionName);
        client.deleteEntityByID(collectionName, Collections.singletonList(idsBefore.get(0)));
        client.flush(collectionName);
        String stats = client.getCollectionStats(collectionName);
        System.out.println(stats);
        JSONObject collectionStats = JSONObject.parseObject(stats);
        int rowCount = collectionStats.getIntValue("row_count");
        assert(rowCount == nb - 1);
    }
}
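getCollectionStats() now hands back the raw JSON document as a String, so the tests parse it themselves with fastjson. The essential step, as used above:

    String stats = client.getCollectionStats(collectionName);
    JSONObject collectionStats = JSONObject.parseObject(stats);
    int rowCount = collectionStats.getIntValue("row_count");
    // per-segment entries are extracted elsewhere via Utils.parseJsonArray(stats, "segments"),
    // a suite helper; the exact key layout of the stats document is assumed from that usage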
@@ -1,8 +1,8 @@
package com;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;

@@ -10,78 +10,156 @@ import java.util.List;
public class TestCompact {
    int nb = Constants.nb;

    // case-01
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testCompactAfterDelete(MilvusClient client, String collectionName) {
        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        List<Long> ids = res.getEntityIds();
        List<Long> ids = Utils.initData(client, collectionName);
        client.deleteEntityByID(collectionName, ids);
        client.flush(collectionName);
        Response res_delete = client.deleteEntityByID(collectionName, ids);
        assert(res_delete.ok());
        client.flush(collectionName);
        CompactParam compactParam = new CompactParam.Builder(collectionName).build();
        Response res_compact = client.compact(compactParam);
        assert(res_compact.ok());
        Response statsResponse = client.getCollectionStats(collectionName);
        assert (statsResponse.ok());
        JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
        client.compact(CompactParam.create(collectionName));
        String statsResponse = client.getCollectionStats(collectionName);
        JSONObject jsonObject = JSONObject.parseObject(statsResponse);
        Assert.assertEquals(jsonObject.getIntValue("data_size"), 0);
        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
        Assert.assertEquals(client.countEntities(collectionName), 0);
    }

    // case-02
    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
    public void testCompactAfterDeleteBinary(MilvusClient client, String collectionName) {
        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        List<Long> ids = res.getEntityIds();
        List<Long> ids = Utils.initBinaryData(client, collectionName);
        client.deleteEntityByID(collectionName, ids);
        client.flush(collectionName);
        Response res_delete = client.deleteEntityByID(collectionName, ids);
        assert(res_delete.ok());
        client.flush(collectionName);
        CompactParam compactParam = new CompactParam.Builder(collectionName).build();
        Response res_compact = client.compact(compactParam);
        assert(res_compact.ok());
        Response statsResponse = client.getCollectionStats(collectionName);
        assert (statsResponse.ok());
        JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
        client.compact(CompactParam.create(collectionName));
        String stats = client.getCollectionStats(collectionName);
        JSONObject jsonObject = JSONObject.parseObject(stats);
        Assert.assertEquals(jsonObject.getIntValue("data_size"), 0);
        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
        Assert.assertEquals(client.countEntities(collectionName), 0);
    }

    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    // case-03
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testCompactNoCollection(MilvusClient client, String collectionName) {
        String name = "";
        CompactParam compactParam = new CompactParam.Builder(name).build();
        Response res_compact = client.compact(compactParam);
        assert(!res_compact.ok());
        client.compact(CompactParam.create(name));
    }

    // case-04
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testCompactEmptyCollection(MilvusClient client, String collectionName) {
        CompactParam compactParam = new CompactParam.Builder(collectionName).build();
        Response res_compact = client.compact(compactParam);
        assert(res_compact.ok());
        String stats = client.getCollectionStats(collectionName);
        JSONObject jsonObject = JSONObject.parseObject(stats);
        client.compact(CompactParam.create(collectionName));
        int data_size = jsonObject.getIntValue("data_size");
        Assert.assertEquals(0, data_size);
    }

    // case-05
    // TODO delete not correct
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    public void testCompactThresholdLessThanDeleted(MilvusClient client, String collectionName) {
        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        client.flush(collectionName);
        Response deleteRes = client.deleteEntityByID(collectionName, res.getEntityIds().subList(0, nb/4));
        assert (deleteRes.ok());
        client.flush(collectionName);
        Response resBefore = client.getCollectionStats(collectionName);
        JSONObject segmentsBefore = (JSONObject)Utils.parseJsonArray(resBefore.getMessage(), "segments").get(0);
        CompactParam compactParam = new CompactParam.Builder(collectionName).withThreshold(0.3).build();
        Response resCompact = client.compact(compactParam);
        assert (resCompact.ok());
        Response resAfter = client.getCollectionStats(collectionName);
        JSONObject segmentsAfter = (JSONObject)Utils.parseJsonArray(resAfter.getMessage(), "segments").get(0);
        Assert.assertEquals(segmentsBefore.get("data_size"), segmentsAfter.get("data_size"));
        int segmentRowLimit = nb+1000;
        String collectionNameNew = collectionName+"_";
        CollectionMapping cm = CollectionMapping.create(collectionNameNew)
                .addField(Constants.intFieldName, DataType.INT64)
                .addField(Constants.floatFieldName, DataType.FLOAT)
                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, Constants.dimension)
                .setParamsInJson(new JsonBuilder()
                        .param("segment_row_limit", segmentRowLimit)
                        .param("auto_id", true)
                        .build());
        client.createCollection(cm);
        List<Long> ids = Utils.initData(client, collectionNameNew);
        client.deleteEntityByID(collectionNameNew, ids.subList(0, nb / 4));
        client.flush(collectionNameNew);
        Assert.assertEquals(client.countEntities(collectionNameNew), nb - (nb / 4));
        // before compact
        String stats = client.getCollectionStats(collectionNameNew);
        JSONObject segmentsBefore = (JSONObject)Utils.parseJsonArray(stats, "segments").get(0);
        client.compact(CompactParam.create(collectionNameNew).setThreshold(0.9));
        // after compact
        String statsAfter = client.getCollectionStats(collectionNameNew);
        JSONObject segmentsAfter = (JSONObject)Utils.parseJsonArray(statsAfter, "segments").get(0);
        Assert.assertEquals(segmentsAfter.get("data_size"), segmentsBefore.get("data_size"));
    }

    // case-06
    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testCompactInvalidThreshold(MilvusClient client, String collectionName) {
        List<Long> ids = Utils.initData(client, collectionName);
        client.deleteEntityByID(collectionName, ids);
        client.flush(collectionName);
        client.compact(CompactParam.create(collectionName).setThreshold(-1.0));
    }

    // // case-07, test CompactAsync callback
    // @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    // public void testCompactAsyncAfterDelete(MilvusClient client, String collectionName) {
    //     // define callback
    //     FutureCallback<Response> futureCallback = new FutureCallback<Response>() {
    //         @Override
    //         public void onSuccess(Response compactResponse) {
    //             assert(compactResponse != null);
    //             assert(compactResponse.ok());
    //
    //             Response statsResponse = client.getCollectionStats(collectionName);
    //             assert(statsResponse.ok());
    //             JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
    //             Assert.assertEquals(jsonObject.getIntValue("data_size"), 0);
    //             Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
    //         }
    //
    //         @Override
    //         public void onFailure(Throwable t) {
    //             System.out.println(t.getMessage());
    //             Assert.assertTrue(false);
    //         }
    //     };
    //
    //     InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
    //     InsertResponse res = client.insert(insertParam);
    //     assert(res.getResponse().ok());
    //     List<Long> ids = res.getEntityIds();
    //     client.flush(collectionName);
    //     Response res_delete = client.deleteEntityByID(collectionName, ids);
    //     assert(res_delete.ok());
    //     client.flush(collectionName);
    //     CompactParam compactParam = new CompactParam.Builder(collectionName).build();
    //
    //     // call compactAsync
    //     ListenableFuture<Response> compactResponseFuture = client.compactAsync(compactParam);
    //     Futures.addCallback(compactResponseFuture, futureCallback, MoreExecutors.directExecutor());
    //
    //     // execute before callback
    //     Response statsResponse = client.getCollectionStats(collectionName);
    //     assert(statsResponse.ok());
    //     JSONObject jsonObject = JSONObject.parseObject(statsResponse.getMessage());
    //     Assert.assertTrue(jsonObject.getIntValue("data_size") > 0);
    //     Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
    // }
    //
    // // case-08, test CompactAsync callback with invalid collection name
    // @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
    // public void testCompactAsyncNoCollection(MilvusClient client, String collectionName) {
    //     // define callback
    //     FutureCallback<Response> futureCallback = new FutureCallback<Response>() {
    //         @Override
    //         public void onSuccess(Response compactResponse) {
    //             assert(compactResponse != null);
    //             assert(!compactResponse.ok());
    //         }
    //
    //         @Override
    //         public void onFailure(Throwable t) {
    //             System.out.println(t.getMessage());
    //             Assert.assertTrue(false);
    //         }
    //     };
    //
    //     String name = "";
    //     CompactParam compactParam = new CompactParam.Builder(name).build();
    //
    //     // call compactAsync
    //     ListenableFuture<Response> compactResponseFuture = client.compactAsync(compactParam);
    //     Futures.addCallback(compactResponseFuture, futureCallback, MoreExecutors.directExecutor());
    // }
}
@@ -1,58 +1,42 @@
 package com;
 
+import io.milvus.client.exception.ClientSideMilvusException;
 import io.milvus.client.*;
 import org.testng.Assert;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import java.util.concurrent.TimeUnit;
 
 public class TestConnect {
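+    // Added comment: the 0.11 client connects inside the MilvusGrpcClient(connectParam)
+    // constructor; failures surface as exceptions instead of a Response object.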
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
|
||||
public void testConnect(String host, int port) throws ConnectFailedException {
|
||||
public void testConnect(String host, int port) throws Exception {
|
||||
System.out.println("Host: "+host+", Port: "+port);
|
||||
MilvusClient client = new MilvusGrpcClient();
|
||||
ConnectParam connectParam = new ConnectParam.Builder()
|
||||
.withHost(host)
|
||||
.withPort(port)
|
||||
.build();
|
||||
Response res = client.connect(connectParam);
|
||||
assert(res.ok());
|
||||
// assert(client.isConnected());
|
||||
MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
|
||||
public void testConnectRepeat(String host, int port) {
|
||||
MilvusGrpcClient client = new MilvusGrpcClient();
|
||||
|
||||
Response res = null;
|
||||
try {
|
||||
ConnectParam connectParam = new ConnectParam.Builder()
|
||||
.withHost(host)
|
||||
.withPort(port)
|
||||
.build();
|
||||
res = client.connect(connectParam);
|
||||
res = client.connect(connectParam);
|
||||
} catch (ConnectFailedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
assert (res.ok());
|
||||
// assert(client.isConnected());
|
||||
ConnectParam connectParam = new ConnectParam.Builder()
|
||||
.withHost(host)
|
||||
.withPort(port)
|
||||
.build();
|
||||
MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
|
||||
MilvusClient client1 = new MilvusGrpcClient(connectParam).withLogging();
|
||||
}
|
||||
|
||||
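+    // Added comment: the rewritten invalid-args case sets a 1 ms keep-alive timeout so a
+    // bad host/port fails fast with ClientSideMilvusException or IllegalArgumentException.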
@Test(dataProvider="InvalidConnectArgs")
|
||||
public void testConnectInvalidConnect_args(String ip, int port) {
|
||||
MilvusClient client = new MilvusGrpcClient();
|
||||
Response res = null;
|
||||
try {
|
||||
ConnectParam connectParam = new ConnectParam.Builder()
|
||||
.withHost(ip)
|
||||
.withPort(port)
|
||||
.build();
|
||||
res = client.connect(connectParam);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
Assert.assertEquals(res, null);
|
||||
// assert(!client.isConnected());
|
||||
// TODO timeout
|
||||
@Test(dataProvider="InvalidConnectArgs", expectedExceptions = {ClientSideMilvusException.class, IllegalArgumentException.class})
|
||||
public void testConnectInvalidConnectArgs(String ip, int port) {
|
||||
ConnectParam connectParam = new ConnectParam.Builder()
|
||||
.withHost(ip)
|
||||
.withPort(port)
|
||||
.withKeepAliveTimeout(1, TimeUnit.MILLISECONDS)
|
||||
.build();
|
||||
MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
|
||||
}
|
||||
|
||||
@DataProvider(name="InvalidConnectArgs")
|
||||
|
@ -62,27 +46,21 @@ public class TestConnect {
             {"1.1.1.1", port},
             {"255.255.0.0", port},
             {"1.2.2", port},
-            // {"中文", port},
-            // {"www.baidu.com", 100000},
+            {"中文", port},
+            {"www.baidu.com", 100000},
             {"127.0.0.1", 100000},
             {"www.baidu.com", 80},
         };
     }
 
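+    // Added comment: close() replaces the old disconnect(); whether a repeated close()
+    // should throw is left as the TODO below.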
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void testDisconnect(MilvusClient client, String collectionName){
|
||||
// assert(!client.isConnected());
|
||||
client.close();
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void testDisconnectRepeatably(MilvusClient client, String collectionName){
|
||||
Response res = null;
|
||||
try {
|
||||
res = client.disconnect();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
assert(!res.ok());
|
||||
// assert(!client.isConnected());
|
||||
}
|
||||
// // TODO
|
||||
// @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
|
||||
// public void testDisconnectRepeatably(MilvusClient client, String collectionName){
|
||||
// client.close();
|
||||
// }
|
||||
}
|
||||
|
|
|
@@ -1,103 +1,80 @@
 package com;
 
 import io.milvus.client.*;
+import io.milvus.client.exception.ServerSideMilvusException;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.LongStream;
 
 public class TestDeleteEntities {
 
+    // case-01
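+    // Added comment: deleting every returned id and flushing should leave the collection empty.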
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testDeleteEntities(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = res.getEntityIds();
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
List<Long> ids = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
Response res_delete = client.deleteEntityByID(collectionName, ids);
|
||||
assert(res_delete.ok());
|
||||
client.deleteEntityByID(collectionName, ids);
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
|
||||
Assert.assertEquals(client.countEntities(collectionName), 0);
|
||||
}
|
||||
|
||||
// case-02
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testDeleteSingleEntity(MilvusClient client, String collectionName) {
|
||||
List<List<Float>> del_vector = new ArrayList<>();
|
||||
del_vector.add(Constants.vectors.get(0));
|
||||
List<Long> del_ids = new ArrayList<>();
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = res.getEntityIds();
|
||||
del_ids.add(ids.get(0));
|
||||
client.flush(collectionName);
|
||||
Response res_delete = client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
|
||||
assert(res_delete.ok());
|
||||
List<Long> ids = Utils.initData(client, collectionName);
|
||||
client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), Constants.nb - 1);
|
||||
Assert.assertEquals(client.countEntities(collectionName), Constants.nb - 1);
|
||||
// Assert getEntityByID
|
||||
GetEntityByIDResponse res_get = client.getEntityByID(collectionName, del_ids);
|
||||
assert(res_get.getResponse().ok());
|
||||
Assert.assertEquals(res_get.getValidIds().size(), 0);
|
||||
Map<Long, Map<String, Object>> resEntity = client.getEntityByID(collectionName, Collections.singletonList(ids.get(0)));
|
||||
Assert.assertEquals(resEntity.size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
// case-03
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testDeleteEntitiesCollectionNotExisted(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
client.flush(collectionName);
|
||||
List<Long> ids = res.getEntityIds();
|
||||
String collectionNameNew = Utils.genUniqueStr(collectionName);
|
||||
Response res_delete = client.deleteEntityByID(collectionNameNew, ids);
|
||||
assert(!res_delete.ok());
|
||||
client.deleteEntityByID(collectionNameNew, new ArrayList<Long>());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
// case-04
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testDeleteEntitiesEmptyCollection(MilvusClient client, String collectionName) {
|
||||
String collectionNameNew = Utils.genUniqueStr(collectionName);
|
||||
List<Long> entityIds = LongStream.range(0, Constants.nb).boxed().collect(Collectors.toList());
|
||||
Response res_delete = client.deleteEntityByID(collectionNameNew, entityIds);
|
||||
assert(!res_delete.ok());
|
||||
client.deleteEntityByID(collectionNameNew, entityIds);
|
||||
}
|
||||
|
||||
// case-05
|
||||
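+    // Added comment: deleting ids that were never assigned must leave the row count at Constants.nb.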
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testDeleteEntityIdNotExisted(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = new ArrayList<Long>();
|
||||
ids.add((long)123456);
|
||||
ids.add((long)1234561);
|
||||
List<Long> ids = Utils.initData(client, collectionName);
|
||||
List<Long> delIds = new ArrayList<Long>();
|
||||
delIds.add(123456L);
|
||||
delIds.add(1234561L);
|
||||
client.deleteEntityByID(collectionName, delIds);
|
||||
client.flush(collectionName);
|
||||
Response res_delete = client.deleteEntityByID(collectionName, ids);
|
||||
assert(res_delete.ok());
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), Constants.nb);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(client.countEntities(collectionName), Constants.nb);
|
||||
}
|
||||
|
||||
// case-06
|
||||
// Below tests binary vectors
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testDeleteEntitiesBinary(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = res.getEntityIds();
|
||||
client.flush(collectionName);
|
||||
Response res_delete = client.deleteEntityByID(collectionName, ids);
|
||||
assert(res_delete.ok());
|
||||
List<Long> ids = Utils.initBinaryData(client, collectionName);
|
||||
client.deleteEntityByID(collectionName, ids);
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), 0);
|
||||
Assert.assertEquals(client.countEntities(collectionName), 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@@ -1,29 +1,29 @@
 package com;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import java.util.concurrent.ExecutionException;
 import io.milvus.client.*;
 import org.apache.commons.lang3.RandomStringUtils;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
 import java.util.ArrayList;
 import java.util.List;
 
 public class TestFlush {
     int segmentRowCount = 50;
     int nb = Constants.nb;
 
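+    // Added comment: flush() now returns void and reports failure by throwing; the
+    // not-existed case below only logs the exception.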
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void testFlushCollectionNotExisted(MilvusClient client, String collectionName) {
|
||||
String newCollection = "not_existed";
|
||||
Response res = client.flush(newCollection);
|
||||
assert(!res.ok());
|
||||
try {
|
||||
client.flush(newCollection);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testFlushEmptyCollection(MilvusClient client, String collectionName) {
|
||||
Response res = client.flush(collectionName);
|
||||
assert(res.ok());
|
||||
client.flush(collectionName);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
|
@ -32,64 +32,66 @@ public class TestFlush {
|
|||
         int collectionNum = 10;
         for (int i = 0; i < collectionNum; i++) {
             names.add(RandomStringUtils.randomAlphabetic(10));
-            CollectionMapping collectionSchema = new CollectionMapping.Builder(names.get(i))
-                    .withFields(Constants.defaultFields)
-                    .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
-                    .build();
-            client.createCollection(collectionSchema);
-            InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFields(Constants.defaultEntities).build();
+            CollectionMapping cm = CollectionMapping
+                    .create(names.get(i))
+                    .addField("int64", DataType.INT64)
+                    .addField("float", DataType.FLOAT)
+                    .addVectorField("float_vector", DataType.VECTOR_FLOAT, Constants.dimension)
+                    .setParamsInJson(new JsonBuilder()
+                            .param("segment_row_limit", Constants.segmentRowLimit)
+                            .param("auto_id", true)
+                            .build());
+            client.createCollection(cm);
+            InsertParam insertParam = Utils.genInsertParam(names.get(i));
             client.insert(insertParam);
             System.out.println("Table " + names.get(i) + " created.");
         }
-        Response res = client.flush(names);
-        assert(res.ok());
+        client.flush(names);
         for (int i = 0; i < collectionNum; i++) {
             // check row count
-            Assert.assertEquals(client.countEntities(names.get(i)).getCollectionEntityCount(), nb);
+            Assert.assertEquals(client.countEntities(names.get(i)), nb);
         }
     }
 
-    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
-    public void testAddCollectionsFlushAsync(MilvusClient client, String collectionName) throws ExecutionException, InterruptedException {
-        List<String> names = new ArrayList<>();
-        for (int i = 0; i < 100; i++) {
-            names.add(RandomStringUtils.randomAlphabetic(10));
-            CollectionMapping collectionSchema = new CollectionMapping.Builder(names.get(i))
-                    .withFields(Constants.defaultFields)
-                    .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
-                    .build();
-            client.createCollection(collectionSchema);
-            InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFields(Constants.defaultEntities).build();
-            client.insert(insertParam);
-            System.out.println("Collection " + names.get(i) + " created.");
-        }
-        ListenableFuture<Response> flushResponseFuture = client.flushAsync(names);
-        flushResponseFuture.get();
-        for (int i = 0; i < 100; i++) {
-            // check row count
-            Assert.assertEquals(client.countEntities(names.get(i)).getCollectionEntityCount(), nb);
-        }
-    }
-
+    // @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
+    // public void testAddCollectionsFlushAsync(MilvusClient client, String collectionName) throws ExecutionException, InterruptedException {
+    //     List<String> names = new ArrayList<>();
+    //     for (int i = 0; i < 100; i++) {
+    //         names.add(RandomStringUtils.randomAlphabetic(10));
+    //         CollectionMapping collectionSchema = new CollectionMapping.Builder(names.get(i))
+    //                 .withFields(Constants.defaultFields)
+    //                 .withParamsInJson(String.format("{\"segment_row_count\": %s}",segmentRowCount))
+    //                 .build();
+    //         client.createCollection(collectionSchema);
+    //         InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFields(Constants.defaultEntities).build();
+    //         client.insert(insertParam);
+    //         System.out.println("Collection " + names.get(i) + " created.");
+    //     }
+    //     ListenableFuture<Response> flushResponseFuture = client.flushAsync(names);
+    //     flushResponseFuture.get();
+    //     for (int i = 0; i < 100; i++) {
+    //         // check row count
+    //         Assert.assertEquals(client.countEntities(names.get(i)).getCollectionEntityCount(), nb);
+    //     }
+    // }
+    //
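+    // Added comment: each iteration inserts nb entities and flushes, so the expected
+    // count grows by nb per pass.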
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testAddFlushMultipleTimes(MilvusClient client, String collectionName) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
client.insert(insertParam);
|
||||
Response res = client.flush(collectionName);
|
||||
assert(res.ok());
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb * (i+1));
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
List<Long> ids = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
Assert.assertEquals(client.countEntities(collectionName), nb * (i+1));
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testAddFlushMultipleTimesBinary(MilvusClient client, String collectionName) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
client.insert(insertParam);
|
||||
Response res = client.flush(collectionName);
|
||||
assert(res.ok());
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb * (i+1));
|
||||
InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
|
||||
List<Long> ids = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
Assert.assertEquals(client.countEntities(collectionName), nb * (i+1));
|
||||
}
|
||||
}
|
||||
}
|
|
@@ -1,11 +1,12 @@
 package com;
 
 import io.milvus.client.*;
+import io.milvus.client.exception.ServerSideMilvusException;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
+import java.nio.ByteBuffer;
+import java.util.*;
 
 public class TestGetEntityByID {
     public List<Long> get_ids = Utils.toListIds(1111);
@@ -13,77 +14,77 @@ public class TestGetEntityByID {
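+    // Added comment: getEntityByID now returns a Map keyed by id; deleted or unknown ids
+    // are simply absent from the result.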
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testGetEntitiesByIdValid(MilvusClient client, String collectionName) {
|
||||
int get_length = 100;
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse resInsert = client.insert(insertParam);
|
||||
List<Long> ids = resInsert.getEntityIds();
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
List<Long> ids = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, get_length));
|
||||
assert (res.getResponse().ok());
|
||||
// assert (res.getValidIds(), ids.subList(0, get_length));
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, ids.subList(0, get_length));
|
||||
for (int i = 0; i < get_length; i++) {
|
||||
List<Map<String,Object>> fieldsMap = res.getFieldsMap();
|
||||
assert (fieldsMap.get(i).get("float_vector").equals(Constants.vectors.get(i)));
|
||||
Map<String,Object> fieldsMap = resEntities.get(ids.get(i));
|
||||
assert (fieldsMap.get("float_vector").equals(Constants.vectors.get(i)));
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testGetEntityByIdAfterDelete(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse resInsert = client.insert(insertParam);
|
||||
List<Long> ids = resInsert.getEntityIds();
|
||||
Response res_delete = client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
|
||||
assert(res_delete.ok());
|
||||
List<Long> ids = Utils.initData(client, collectionName);
|
||||
client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(1)));
|
||||
client.flush(collectionName);
|
||||
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
|
||||
assert (res.getResponse().ok());
|
||||
assert (res.getFieldsMap().size() == 0);
|
||||
List<Long> getIds = ids.subList(0,2);
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, getIds);
|
||||
Assert.assertEquals(resEntities.size(), getIds.size()-1);
|
||||
Assert.assertEquals(resEntities.get(getIds.get(0)).get(Constants.floatVectorFieldName), Constants.vectors.get(0));
|
||||
}
|
||||
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testGetEntityByIdCollectionNameNotExisted(MilvusClient client, String collectionName) {
|
||||
String newCollection = "not_existed";
|
||||
GetEntityByIDResponse res = client.getEntityByID(newCollection, get_ids);
|
||||
assert(!res.getResponse().ok());
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(newCollection, get_ids);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testGetVectorIdNotExisted(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
GetEntityByIDResponse res = client.getEntityByID(collectionName, get_ids);
|
||||
assert (res.getFieldsMap().size() == 0);
|
||||
List<Long> ids = Utils.initData(client, collectionName);
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, get_ids);
|
||||
Assert.assertEquals(resEntities.size(), 0);
|
||||
}
|
||||
|
||||
// Binary tests
|
||||
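+    // Added comment: binary vectors travel as ByteBuffer values on a DataType.VECTOR_BINARY field.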
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testGetEntityByIdValidBinary(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
InsertResponse resInsert = client.insert(insertParam);
|
||||
List<Long> ids = resInsert.getEntityIds();
|
||||
int get_length = 20;
|
||||
List<Long> intValues = new ArrayList<>(Constants.nb);
|
||||
List<Float> floatValues = new ArrayList<>(Constants.nb);
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
|
||||
for (int i = 0; i < Constants.nb; ++i) {
|
||||
intValues.add((long) i);
|
||||
floatValues.add((float) i);
|
||||
}
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, intValues)
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
|
||||
List<Long> ids = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
|
||||
assert res.getFieldsMap().get(0).get(Constants.binaryFieldName).equals(Constants.vectorsBinary.get(0).rewind());
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, ids.subList(0, get_length));
|
||||
for (int i = 0; i < get_length; i++) {
|
||||
assert (resEntities.get(ids.get(i)).get(Constants.binaryVectorFieldName).equals(vectors.get(i)));
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testGetEntityByIdAfterDeleteBinary(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
InsertResponse resInsert = client.insert(insertParam);
|
||||
List<Long> ids = resInsert.getEntityIds();
|
||||
Response res_delete = client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
|
||||
assert(res_delete.ok());
|
||||
List<Long> ids = Utils.initBinaryData(client, collectionName);
|
||||
client.deleteEntityByID(collectionName, Collections.singletonList(ids.get(0)));
|
||||
client.flush(collectionName);
|
||||
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
|
||||
assert (res.getFieldsMap().size() == 0);
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, ids.subList(0, 1));
|
||||
Assert.assertEquals(resEntities.size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testGetEntityIdNotExistedBinary(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
GetEntityByIDResponse res = client.getEntityByID(collectionName, get_ids);
|
||||
assert (res.getFieldsMap().size() == 0);
|
||||
List<Long> ids = Utils.initBinaryData(client, collectionName);
|
||||
Map<Long, Map<String, Object>> resEntities = client.getEntityByID(collectionName, get_ids);
|
||||
Assert.assertEquals(resEntities.size(), 0);
|
||||
}
|
||||
}
|
|
@@ -3,148 +3,196 @@ package com;
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
 import io.milvus.client.*;
+import io.milvus.client.exception.ClientSideMilvusException;
+import io.milvus.client.exception.ServerSideMilvusException;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 import java.util.logging.Logger;
 
 import java.util.List;
 
 public class TestIndex {
 
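+    // Added comment: case-01 builds an IVF_SQ8/L2 index, then verifies the index_type
+    // recorded on every indexed file in the collection stats.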
+    // case-01
     @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
     public void testCreateIndex(MilvusClient client, String collectionName) {
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
-        client.insert(insertParam);
-        Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
-        Response res_create = client.createIndex(index);
-        assert(res_create.ok());
-        Response statsResponse = client.getCollectionStats(collectionName);
-        // TODO: should check getCollectionStats
-        if(statsResponse.ok()) {
-            JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
-            filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
-                    Assert.assertEquals(file.get("index_type"), Constants.indexType));
-        }
+        List<Long> ids = Utils.initData(client, collectionName);
+        Index index = Index
+                .create(collectionName, Constants.floatVectorFieldName)
+                .setIndexType(IndexType.IVF_SQ8)
+                .setMetricType(MetricType.L2)
+                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
+        client.createIndex(index);
+        String stats = client.getCollectionStats(collectionName);
+        JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
+        filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
+                Assert.assertEquals(file.get("index_type"), Constants.indexType.toString()));
     }
 
+    // case-02
     @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
     public void testCreateIndexBinary(MilvusClient client, String collectionName) {
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
-        client.insert(insertParam);
-        Index index = new Index.Builder(collectionName, Constants.binaryFieldName).withParamsInJson(Constants.binaryIndexParam).build();
-        Response res_create = client.createIndex(index);
-        assert(res_create.ok());
-        Response statsResponse = client.getCollectionStats(collectionName);
-        // TODO: should check getCollectionStats
-        if(statsResponse.ok()) {
-            JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
-            filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
-                    Assert.assertEquals(file.get("index_type"), Constants.defaultBinaryIndexType));
-        }
+        List<Long> ids = Utils.initBinaryData(client, collectionName);
+        Index index = Index
+                .create(collectionName, Constants.binaryVectorFieldName)
+                .setIndexType(Constants.defaultBinaryIndexType)
+                .setMetricType(MetricType.JACCARD)
+                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
+        client.createIndex(index);
+        String stats = client.getCollectionStats(collectionName);
+        JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
+        filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
+                Assert.assertEquals(file.get("index_type"), Constants.defaultBinaryIndexType.toString()));
     }
 
+    // case-03
     @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
     public void testCreateIndexRepeatably(MilvusClient client, String collectionName) {
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
-        client.insert(insertParam);
-        Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
-        Response res_create = client.createIndex(index);
-        assert(res_create.ok());
-        Response res_create_2 = client.createIndex(index);
-        assert(res_create_2.ok());
+        List<Long> ids = Utils.initData(client, collectionName);
+        Index index = Index
+                .create(collectionName, Constants.floatVectorFieldName)
+                .setIndexType(IndexType.IVF_SQ8)
+                .setMetricType(MetricType.L2)
+                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
+        client.createIndex(index);
+        client.createIndex(index);
     }
 
+    // case-04
     @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
     public void testCreateIndexWithNoVector(MilvusClient client, String collectionName) {
-        Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
-        Response res_create = client.createIndex(index);
-        assert(res_create.ok());
+        Index index = Index
+                .create(collectionName, Constants.floatVectorFieldName)
+                .setIndexType(IndexType.IVF_SQ8)
+                .setMetricType(MetricType.L2)
+                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
+        client.createIndex(index);
     }
 
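+    // Added comment: index operations on a missing collection raise ServerSideMilvusException;
+    // on a closed client they raise ClientSideMilvusException.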
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
// case-05
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testCreateIndexTableNotExisted(MilvusClient client, String collectionName) {
|
||||
String collectionNameNew = Utils.genUniqueStr(collectionName);
|
||||
Index index = new Index.Builder(collectionNameNew, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
|
||||
Response res_create = client.createIndex(index);
|
||||
assert(!res_create.ok());
|
||||
Index index = Index
|
||||
.create(collectionNameNew, Constants.floatVectorFieldName)
|
||||
.setIndexType(IndexType.IVF_SQ8)
|
||||
.setMetricType(MetricType.L2)
|
||||
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
|
||||
client.createIndex(index);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
// case-06
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
|
||||
public void testCreateIndexWithoutConnect(MilvusClient client, String collectionName) {
|
||||
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
|
||||
Response res_create = client.createIndex(index);
|
||||
assert(!res_create.ok());
|
||||
Index index = Index
|
||||
.create(collectionName, Constants.floatVectorFieldName)
|
||||
.setIndexType(IndexType.IVF_SQ8)
|
||||
.setMetricType(MetricType.L2)
|
||||
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
|
||||
client.createIndex(index);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
// case-07
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testCreateIndexInvalidNList(MilvusClient client, String collectionName) {
|
||||
int n_list = 0;
|
||||
String indexParamNew = Utils.setIndexParam(Constants.indexType, "L2", n_list);
|
||||
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(indexParamNew).build();
|
||||
Response res_create = client.createIndex(index);
|
||||
assert(!res_create.ok());
|
||||
Index index = Index
|
||||
.create(collectionName, Constants.floatVectorFieldName)
|
||||
.setIndexType(IndexType.IVF_SQ8)
|
||||
.setMetricType(MetricType.L2)
|
||||
.setParamsInJson(new JsonBuilder().param("nlist", n_list).build());
|
||||
client.createIndex(index);
|
||||
}
|
||||
|
||||
// # 3407
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// #3407
|
||||
// case-08
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testCreateIndexInvalidMetricTypeBinary(MilvusClient client, String collectionName) {
|
||||
String metric_type = "L2";
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
client.insert(insertParam);
|
||||
String indexParamNew = Utils.setIndexParam(Constants.defaultBinaryIndexType, metric_type, Constants.n_list);
|
||||
Index createIndexParam = new Index.Builder(collectionName, Constants.binaryFieldName).withParamsInJson(indexParamNew).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert (!res_create.ok());
|
||||
MetricType metric_type = MetricType.L2;
|
||||
List<Long> ids = Utils.initBinaryData(client, collectionName);
|
||||
Index index = Index
|
||||
.create(collectionName, Constants.binaryVectorFieldName)
|
||||
.setIndexType(IndexType.BIN_IVF_FLAT)
|
||||
.setMetricType(metric_type)
|
||||
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
|
||||
client.createIndex(index);
|
||||
}
|
||||
|
     // #3408
+    // case-09
     @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
     public void testDropIndex(MilvusClient client, String collectionName) {
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
-        client.insert(insertParam);
-        Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
-        Response res_create = client.createIndex(index);
-        assert(res_create.ok());
-        Response res_drop = client.dropIndex(collectionName, Constants.floatFieldName);
-        assert(res_drop.ok());
-        Response statsResponse = client.getCollectionStats(collectionName);
-        // TODO: should check getCollectionStats
-        if(statsResponse.ok()) {
-            JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
-            filesJsonArray.stream().map(item -> (JSONObject) item).forEach(file->{
-                assert (!file.containsKey("index_type"));
-            });
+        List<Long> ids = Utils.initData(client, collectionName);
+        Index index = Index
+                .create(collectionName, Constants.floatVectorFieldName)
+                .setIndexType(Constants.indexType)
+                .setMetricType(Constants.defaultMetricType)
+                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
+        client.createIndex(index);
+        client.dropIndex(collectionName, Constants.floatVectorFieldName);
+        String stats = client.getCollectionStats(collectionName);
+        JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
+        for (Object item : filesJsonArray) {
+            JSONObject file = (JSONObject) item;
+            Assert.assertFalse(file.containsKey("index_type"));
+        }
     }
 
+    // case-10
     @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
     public void testDropIndexBinary(MilvusClient client, String collectionName) {
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
-        client.insert(insertParam);
-        Index index = new Index.Builder(collectionName, Constants.binaryFieldName).withParamsInJson(Constants.binaryIndexParam).build();
-        Response res_create = client.createIndex(index);
-        assert(res_create.ok());
-        Response res_drop = client.dropIndex(collectionName, Constants.binaryFieldName);
-        assert(res_drop.ok());
-        Response statsResponse = client.getCollectionStats(collectionName);
-        // TODO: should check getCollectionStats
-        if(statsResponse.ok()) {
-            JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
-            filesJsonArray.stream().map(item -> (JSONObject) item).forEach(file->{
-                assert (!file.containsKey("index_type"));
-            });
+        List<Long> ids = Utils.initBinaryData(client, collectionName);
+        Index index = Index
+                .create(collectionName, Constants.binaryVectorFieldName)
+                .setIndexType(Constants.defaultBinaryIndexType)
+                .setMetricType(Constants.defaultBinaryMetricType)
+                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
+        client.createIndex(index);
+        client.dropIndex(collectionName, Constants.binaryVectorFieldName);
+        String stats = client.getCollectionStats(collectionName);
+        JSONArray filesJsonArray = Utils.parseJsonArray(stats, "files");
+        for (Object item : filesJsonArray) {
+            JSONObject file = (JSONObject) item;
+            Assert.assertFalse(file.containsKey("index_type"));
+        }
     }
 
-    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
-    public void testDropIndexTableNotExisted(MilvusClient client, String collectionName) {
+    // case-11
+    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
+    public void testDropIndexCollectionNotExisted(MilvusClient client, String collectionName) {
         String collectionNameNew = Utils.genUniqueStr(collectionName);
-        Response res_drop = client.dropIndex(collectionNameNew, Constants.floatFieldName);
-        assert(!res_drop.ok());
+        client.dropIndex(collectionNameNew, Constants.floatVectorFieldName);
     }
 
-    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+    // case-12
+    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
     public void testDropIndexWithoutConnect(MilvusClient client, String collectionName) {
-        Response res_drop = client.dropIndex(collectionName, Constants.floatFieldName);
-        assert(!res_drop.ok());
+        client.dropIndex(collectionName, Constants.floatVectorFieldName);
     }
+    //
+    // // case-13
+    // @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
+    // public void testAsyncIndex(MilvusClient client, String collectionName) {
+    //     Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
+    //     ListenableFuture<Response> createIndexResFuture = client.createIndexAsync(index);
+    //     Futures.addCallback(
+    //             createIndexResFuture, new FutureCallback<Response>() {
+    //                 @Override
+    //                 public void onSuccess(Response createIndexResponse) {
+    //                     Assert.assertNotNull(createIndexResponse);
+    //                     Assert.assertTrue(createIndexResponse.ok());
+    //                     Response statsResponse = client.getCollectionStats(collectionName);
+    //                     if(statsResponse.ok()) {
+    //                         JSONArray filesJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "files");
+    //                         filesJsonArray.stream().map(item-> (JSONObject)item).filter(item->item.containsKey("index_type")).forEach(file->
+    //                                 Assert.assertEquals(file.get("index_type"), Constants.indexType));
+    //                     }
+    //                 }
+    //                 @Override
+    //                 public void onFailure(Throwable t) {
+    //                     System.out.println(t.getMessage());
+    //                 }
+    //             }, MoreExecutors.directExecutor()
+    //     );
+    // }
 
 }
 
@@ -3,7 +3,8 @@ package com;
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
 import io.milvus.client.*;
 import org.apache.commons.lang3.RandomStringUtils;
+import io.milvus.client.exception.ClientSideMilvusException;
+import io.milvus.client.exception.ServerSideMilvusException;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -12,175 +13,254 @@ import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.LongStream;
 
 
 public class TestInsertEntities {
     int dimension = Constants.dimension;
     String tag = "tag";
     int nb = Constants.nb;
+    Map<String, List> entities = Constants.defaultEntities;
+    Map<String, List> binaryEntities = Constants.defaultBinaryEntities;
 
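+    // Added comment: inserting into a missing collection throws server-side; inserting
+    // through a closed client throws client-side.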
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
// case-01
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testInsertEntitiesCollectionNotExisted(MilvusClient client, String collectionName) {
|
||||
String collectionNameNew = collectionName + "_";
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionNameNew)
|
||||
.withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(!res.getResponse().ok());
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionNameNew);
|
||||
client.insert(insertParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
// case-02
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class, expectedExceptions = ClientSideMilvusException.class)
|
||||
public void testInsertEntitiesWithoutConnect(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(!res.getResponse().ok());
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
}
|
||||
|
||||
+    // case-03
     @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
     public void testInsertEntities(MilvusClient client, String collectionName) {
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
-        InsertResponse res = client.insert(insertParam);
-        assert(res.getResponse().ok());
-        Response res_flush = client.flush(collectionName);
-        assert(res_flush.ok());
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
+        List<Long> vectorIds = client.insert(insertParam);
+        client.flush(collectionName);
         // Assert collection row count
-        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
+        Assert.assertEquals(vectorIds.size(), nb);
+        Assert.assertEquals(client.countEntities(collectionName), nb);
     }
 
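+    // Added comment: explicit entity ids passed via setEntityIds must be returned unchanged by insert.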
+    // case-04
     @Test(dataProvider = "IdCollection", dataProviderClass = MainClass.class)
     public void testInsertEntityWithIds(MilvusClient client, String collectionName) {
         // Add vectors with ids
         List<Long> entityIds = LongStream.range(0, nb).boxed().collect(Collectors.toList());
-        InsertParam insertParam = new InsertParam.Builder(collectionName)
-                .withFields(Constants.defaultEntities)
-                .withEntityIds(entityIds)
-                .build();
-        InsertResponse res = client.insert(insertParam);
-        assert(res.getResponse().ok());
-        Response res_flush = client.flush(collectionName);
-        assert(res_flush.ok());
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
+                .setEntityIds(entityIds);
+        List<Long> vectorIds = client.insert(insertParam);
+        client.flush(collectionName);
         // Assert ids and collection row count
-        Assert.assertEquals(res.getEntityIds(), entityIds);
-        Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
+        Assert.assertEquals(vectorIds, entityIds);
+        Assert.assertEquals(client.countEntities(collectionName), nb);
     }
 
-    @Test(dataProvider = "IdCollection", dataProviderClass = MainClass.class)
+    // case-05
+    @Test(dataProvider = "IdCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
     public void testInsertEntityWithInvalidIds(MilvusClient client, String collectionName) {
         // Add vectors with ids
         List<Long> entityIds = LongStream.range(0, nb+1).boxed().collect(Collectors.toList());
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).withEntityIds(entityIds).build();
-        InsertResponse res = client.insert(insertParam);
-        assert(!res.getResponse().ok());
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
+                .setEntityIds(entityIds);
+        client.insert(insertParam);
     }
 
-    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
+    // case-06
+    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
     public void testInsertEntityWithInvalidDimension(MilvusClient client, String collectionName) {
         List<List<Float>> vectors = Utils.genVectors(nb, dimension+1, true);
-        List<Map<String,Object>> entities = Utils.genDefaultEntities(dimension+1,nb,vectors);
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(entities).build();
-        InsertResponse res = client.insert(insertParam);
-        assert(!res.getResponse().ok());
+        Map<String, List> entities = Utils.genDefaultEntities(nb,vectors);
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
+        client.insert(insertParam);
     }
 
-    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
+    // case-07
+    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
     public void testInsertEntityWithInvalidVectors(MilvusClient client, String collectionName) {
-        List<Map<String,Object>> invalidEntities = Utils.genDefaultEntities(dimension,nb,new ArrayList<>());
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(invalidEntities).build();
-        InsertResponse res = client.insert(insertParam);
-        assert(!res.getResponse().ok());
+        Map<String, List> entities = Utils.genDefaultEntities(nb,new ArrayList<>());
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
+        client.insert(insertParam);
    }
 
     // ----------------------------- partition cases in Insert ---------------------------------
-    // Add vectors into collection with given tag
+    // case-08: Add vectors into collection with given tag
     @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
     public void testInsertEntityPartition(MilvusClient client, String collectionName) {
-        Response createpResponse = client.createPartition(collectionName, tag);
-        assert(createpResponse.ok());
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).withPartitionTag(tag).build();
-        InsertResponse res = client.insert(insertParam);
-        assert(res.getResponse().ok());
-        Response res_flush = client.flush(collectionName);
-        assert(res_flush.ok());
+        client.createPartition(collectionName, tag);
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
+                .setPartitionTag(tag);
+        client.insert(insertParam);
+        client.flush(collectionName);
         // Assert collection row count
-        Response statsResponse = client.getCollectionStats(collectionName);
-        if(statsResponse.ok()) {
-            JSONArray partitionsJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "partitions");
-            partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
-                Assert.assertEquals(obj.get("row_count"), nb);
-                Assert.assertEquals(obj.get("tag"), tag);
-            });
-        }
+        String stats = client.getCollectionStats(collectionName);
+        JSONArray partitionsJsonArray = Utils.parseJsonArray(stats, "partitions");
+        partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
+            Assert.assertEquals(obj.get("row_count"), nb);
+            Assert.assertEquals(obj.get("tag"), tag);
+        });
     }
 
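+    // Added comment: a random, never-created partition tag makes the insert fail server-side.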
-    // Add vectors into collection, which tag not existed
-    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
+    // case-09: Add vectors into collection, which tag not existed
+    @Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
     public void testInsertEntityPartitionTagNotExisted(MilvusClient client, String collectionName) {
-        Response createpResponse = client.createPartition(collectionName, tag);
-        assert(createpResponse.ok());
-        InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).withPartitionTag(tag).build();
-        InsertResponse res = client.insert(insertParam);
-        assert(!res.getResponse().ok());
+        String tag = RandomStringUtils.randomAlphabetic(10);
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName))
+                .setPartitionTag(tag);
+        client.insert(insertParam);
    }
 
// Binary tests
|
||||
// case-10: Binary tests
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testInsertEntityPartitionABinary(MilvusClient client, String collectionName) {
|
||||
Response createpResponse = client.createPartition(collectionName, tag);
|
||||
assert (createpResponse.ok());
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).withPartitionTag(tag).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
Response res_flush = client.flush(collectionName);
|
||||
assert(res_flush.ok());
|
||||
// Assert collection row count
|
||||
Response statsResponse = client.getCollectionStats(collectionName);
|
||||
if(statsResponse.ok()) {
|
||||
JSONArray partitionsJsonArray = Utils.parseJsonArray(statsResponse.getMessage(), "partitions");
|
||||
partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
|
||||
Assert.assertEquals(obj.get("tag"), tag);
|
||||
Assert.assertEquals(obj.get("row_count"), nb);
|
||||
});
|
||||
client.createPartition(collectionName, tag);
|
||||
List<Long> intValues = new ArrayList<>(Constants.nb);
|
||||
List<Float> floatValues = new ArrayList<>(Constants.nb);
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
|
||||
for (int i = 0; i < Constants.nb; ++i) {
|
||||
intValues.add((long) i);
|
||||
floatValues.add((float) i);
|
||||
}
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, intValues)
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors)
|
||||
.setPartitionTag(tag);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
String stats = client.getCollectionStats(collectionName);
|
||||
JSONArray partitionsJsonArray = Utils.parseJsonArray(stats, "partitions");
|
||||
partitionsJsonArray.stream().map(item -> (JSONObject) item).filter(item->item.containsValue(tag)).forEach(obj -> {
|
||||
Assert.assertEquals(obj.get("tag"), tag);
|
||||
Assert.assertEquals(obj.get("row_count"), nb);
|
||||
});
|
||||
}
|
||||
|
||||
// case-11
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testInsertEntityBinary(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
Response res_flush = client.flush(collectionName);
|
||||
assert(res_flush.ok());
|
||||
List<Long> intValues = new ArrayList<>(Constants.nb);
|
||||
List<Float> floatValues = new ArrayList<>(Constants.nb);
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
|
||||
for (int i = 0; i < Constants.nb; ++i) {
|
||||
intValues.add((long) i);
|
||||
floatValues.add((float) i);
|
||||
}
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, intValues)
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
|
||||
List<Long> vectorIds = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
|
||||
Assert.assertEquals(vectorIds.size(), nb);
|
||||
Assert.assertEquals(client.countEntities(collectionName), nb);
|
||||
}
|
||||
|
||||
// case-12
|
||||
@Test(dataProvider = "BinaryIdCollection", dataProviderClass = MainClass.class)
|
||||
public void testInsertBinaryEntityWithIds(MilvusClient client, String collectionName) {
|
||||
// Add vectors with ids
|
||||
List<Long> entityIds = LongStream.range(0, nb).boxed().collect(Collectors.toList());
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).withEntityIds(entityIds).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
Response res_flush = client.flush(collectionName);
|
||||
assert(res_flush.ok());
|
||||
List<Long> intValues = new ArrayList<>(Constants.nb);
|
||||
List<Float> floatValues = new ArrayList<>(Constants.nb);
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
|
||||
for (int i = 0; i < Constants.nb; ++i) {
|
||||
intValues.add((long) i);
|
||||
floatValues.add((float) i);
|
||||
}
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, intValues)
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors)
|
||||
.setEntityIds(entityIds);
|
||||
List<Long> vectorIds = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
// Assert collection row count
|
||||
Assert.assertEquals(entityIds, res.getEntityIds());
|
||||
Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
|
||||
Assert.assertEquals(entityIds, vectorIds);
|
||||
Assert.assertEquals(client.countEntities(collectionName), nb);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// case-13
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testInsertBinaryEntityWithInvalidIds(MilvusClient client, String collectionName) {
|
||||
// Add vectors with ids
|
||||
List<Long> invalidEntityIds = LongStream.range(0, nb+1).boxed().collect(Collectors.toList());
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).withEntityIds(invalidEntityIds).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(!res.getResponse().ok());
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, binaryEntities.get(Constants.intFieldName))
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, binaryEntities.get(Constants.floatFieldName))
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, binaryEntities.get(Constants.binaryVectorFieldName))
|
||||
.setEntityIds(invalidEntityIds);
|
||||
client.insert(insertParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// case-14
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testInsertBinaryEntityWithInvalidDimension(MilvusClient client, String collectionName) {
|
||||
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension-1);
|
||||
List<Map<String,Object>> binaryEntities = Utils.genDefaultBinaryEntities(dimension-1,nb,vectorsBinary);
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(binaryEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(!res.getResponse().ok());
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(nb, dimension-1);
|
||||
Map<String, List> entities = Utils.genDefaultBinaryEntities(nb,vectors);
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, entities.get(Constants.binaryVectorFieldName));
|
||||
client.insert(insertParam);
|
||||
}
|
||||
|
||||
// @Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
// public void testAsyncInsert(MilvusClient client, String collectionName) {
|
||||
// InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
// ListenableFuture<InsertResponse> insertResFuture = client.insertAsync(insertParam);
|
||||
// Futures.addCallback(
|
||||
// insertResFuture, new FutureCallback<InsertResponse>() {
|
||||
// @Override
|
||||
// public void onSuccess(InsertResponse insertResponse) {
|
||||
// Assert.assertNotNull(insertResponse);
|
||||
// Assert.assertTrue(insertResponse.ok());
|
||||
// Assert.assertEquals(client.countEntities(collectionName).getCollectionEntityCount(), nb);
|
||||
// }
|
||||
// @Override
|
||||
// public void onFailure(Throwable t) {
|
||||
// System.out.println(t.getMessage());
|
||||
// }
|
||||
// }, MoreExecutors.directExecutor()
|
||||
// );
|
||||
// }
|
||||
}
|
||||
|
|
|
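The diff above migrates the insert tests from the Response-based 0.8 client to the fluent API: fields move from a single map on InsertParam.Builder to per-field addField/addVectorField calls, and status checks give way to exceptions. A minimal sketch of the pattern, assuming the 0.9-era signatures shown in this diff (intValues, floatValues, and vectors stand in for the caller's column data):

// Fluent insert, as used by the updated tests: declare each column explicitly,
// let server-side failures surface as ServerSideMilvusException, and receive
// the generated entity ids directly instead of unwrapping an InsertResponse.
InsertParam insertParam = InsertParam
        .create(collectionName)
        .addField(Constants.intFieldName, DataType.INT64, intValues)
        .addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
        .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, vectors);
List<Long> entityIds = client.insert(insertParam);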
@@ -1,4 +1,4 @@
//package com;
package com;//package com;
//
//import io.milvus.client.*;
//import org.apache.commons.lang3.RandomStringUtils;
@@ -1,4 +1,4 @@
//package com;
package com;//package com;
//
//import io.milvus.client.*;
//import org.apache.commons.cli.*;
@@ -1,13 +1,13 @@
package com;

import io.milvus.client.*;
import io.milvus.client.exception.ServerSideMilvusException;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;

public class TestPartition {
int dimension = 128;

// ----------------------------- create partition cases in ---------------------------------
@@ -15,22 +15,19 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testCreatePartition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
Assert.assertEquals(client.hasPartition(collectionName, tag), true);
// show partitions
List<String> partitions = client.listPartitions(collectionName).getPartitionList();
System.out.println(partitions);
List<String> partitions = client.listPartitions(collectionName);
Assert.assertTrue(partitions.contains(tag));
}

// create partition, tag name existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testCreatePartitionTagNameExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (!createpResponseNew.ok());
client.createPartition(collectionName, tag);
client.createPartition(collectionName, tag);
}

// ----------------------------- has partition cases in ---------------------------------
@@ -38,23 +35,17 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testHasPartitionTagNameNotExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
String tagNew = RandomStringUtils.randomAlphabetic(10);
HasPartitionResponse haspResponse = client.hasPartition(collectionName, tagNew);
assert (haspResponse.ok());
Assert.assertFalse(haspResponse.hasPartition());
Assert.assertFalse(client.hasPartition(collectionName, tagNew));
}

// has partition, tag name existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testHasPartitionTagNameExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
HasPartitionResponse haspResponse = client.hasPartition(collectionName, tag);
assert (haspResponse.ok());
Assert.assertTrue(haspResponse.hasPartition());
client.createPartition(collectionName, tag);
Assert.assertTrue(client.hasPartition(collectionName, tag));
}

// ----------------------------- drop partition cases in ---------------------------------
@@ -63,62 +54,47 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testDropPartition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (createpResponseNew.ok());
Response response = client.dropPartition(collectionName, tag);
assert (response.ok());
client.createPartition(collectionName, tag);
client.dropPartition(collectionName, tag);
// show partitions
System.out.println(client.listPartitions(collectionName).getPartitionList());
int length = client.listPartitions(collectionName).getPartitionList().size();
System.out.println(client.listPartitions(collectionName));
int length = client.listPartitions(collectionName).size();
// _default
Assert.assertEquals(length, 1);
}

@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionDefault(MilvusClient client, String collectionName) {
String tag = "_default";
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (!createpResponseNew.ok());
// show partitions
// System.out.println(client.listPartitions(collectionName).getPartitionList());
// int length = client.listPartitions(collectionName).getPartitionList().size();
// // _default
// Assert.assertEquals(length, 1);
client.createPartition(collectionName, tag);
}

// drop a partition repeat created before, drop by partition name
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionRepeat(MilvusClient client, String collectionName) throws InterruptedException {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
Response response = client.dropPartition(collectionName, tag);
assert (response.ok());
Thread.currentThread().sleep(2000);
Response newResponse = client.dropPartition(collectionName, tag);
assert (!newResponse.ok());
client.createPartition(collectionName, tag);
client.dropPartition(collectionName, tag);
Thread.sleep(2000);
client.dropPartition(collectionName, tag);
}

// drop a partition not created before
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionNotExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
String tagNew = RandomStringUtils.randomAlphabetic(10);
Response response = client.dropPartition(collectionName, tagNew);
assert(!response.ok());
client.dropPartition(collectionName, tagNew);
}

// drop a partition not created before
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testDropPartitionTagNotExisted(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert(createpResponse.ok());
client.createPartition(collectionName, tag);
String newTag = RandomStringUtils.randomAlphabetic(10);
Response response = client.dropPartition(collectionName, newTag);
assert(!response.ok());
client.dropPartition(collectionName, newTag);
}

// ----------------------------- show partitions cases in ---------------------------------
@@ -127,27 +103,22 @@ public class TestPartition {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testShowPartitions(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
ListPartitionsResponse response = client.listPartitions(collectionName);
assert (response.getResponse().ok());
Assert.assertTrue(response.getPartitionList().contains(tag));
client.createPartition(collectionName, tag);
List<String> partitions = client.listPartitions(collectionName);
Assert.assertTrue(partitions.contains(tag));
}

// create multi partition, then show partitions
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testShowPartitionsMulti(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
client.createPartition(collectionName, tag);
String tagNew = RandomStringUtils.randomAlphabetic(10);
Response newCreatepResponse = client.createPartition(collectionName, tagNew);
assert (newCreatepResponse.ok());
ListPartitionsResponse response = client.listPartitions(collectionName);
assert (response.getResponse().ok());
System.out.println(response.getPartitionList());
Assert.assertTrue(response.getPartitionList().contains(tag));
Assert.assertTrue(response.getPartitionList().contains(tagNew));
client.createPartition(collectionName, tagNew);
List<String> partitions = client.listPartitions(collectionName);
System.out.println(partitions);
Assert.assertTrue(partitions.contains(tag));
Assert.assertTrue(partitions.contains(tagNew));
}

}
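The partition tests above follow the same migration: every call that used to return a Response (or HasPartitionResponse/ListPartitionsResponse) now returns its payload directly and signals errors with ServerSideMilvusException. A condensed sketch of the new call shapes, assuming the signatures shown in this hunk:

// Create, query, list, and drop a partition with the exception-based API;
// a failed call (duplicate tag, missing tag, or "_default") now throws
// ServerSideMilvusException instead of returning a not-ok Response.
client.createPartition(collectionName, tag);
boolean exists = client.hasPartition(collectionName, tag);        // plain boolean
List<String> partitions = client.listPartitions(collectionName);  // plain list
client.dropPartition(collectionName, tag);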
@@ -1,25 +1,25 @@
package com;

import io.milvus.client.*;
import org.testng.annotations.Test;

public class TestPing {
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void testServerStatus(String host, int port) throws ConnectFailedException {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
Response res = client.getServerStatus();
assert (res.ok());
}

@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void testServerStatusWithoutConnected(MilvusClient client, String collectionName) throws ConnectFailedException {
Response res = client.getServerStatus();
assert (!res.ok());
}
}
//package com1;
//
//import io.milvus.client.*;
//import org.testng.annotations.Test;
//
//public class TestPing {
// @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
// public void testServerStatus(String host, int port) throws ConnectFailedException {
// System.out.println("Host: "+host+", Port: "+port);
// MilvusClient client = new MilvusGrpcClient();
// ConnectParam connectParam = new ConnectParam.Builder()
// .withHost(host)
// .withPort(port)
// .build();
// client.connect(connectParam);
// Response res = client.getServerStatus();
// assert (res.ok());
// }
//
// @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// public void testServerStatusWithoutConnected(MilvusClient client, String collectionName) throws ConnectFailedException {
// Response res = client.getServerStatus();
// assert (!res.ok());
// }
//}
@@ -1,4 +1,4 @@
//package com;
package com;//package com;
//
//import io.milvus.client.*;
//import org.apache.commons.lang3.RandomStringUtils;
@@ -1,52 +1,53 @@
package com;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import io.milvus.client.exception.InvalidDsl;
import io.milvus.client.exception.ServerSideMilvusException;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.nio.ByteBuffer;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TestSearchEntities {

int small_nb = 10;
int n_probe = 20;
int top_k = 10;
int nq = 5;
int top_k = Constants.topk;
int nq = Constants.nq;

List<List<Float>> queryVectors = Constants.vectors.subList(0, nq);
List<ByteBuffer> queryVectorsBinary = Constants.vectorsBinary.subList(0, nq);
public String dsl = Constants.searchParam;

@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public String floatDsl = Constants.searchParam;
public String binaryDsl = Constants.binarySearchParam;

@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
public void testSearchCollectionNotExisted(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
assert (!res_search.getResponse().ok());
SearchParam searchParam = SearchParam.create(collectionNameNew).setDsl(floatDsl);
client.search(searchParam);
}

@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCollectionEmpty(MilvusClient client, String collectionName) {
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultIdsList().size(), 0);
}

// # 3429
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCollection(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
List<Long> ids = client.insert(insertParam);
client.flush(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
SearchResult res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultIdsList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultDistancesList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultIdsList().get(0).size(), Constants.topk);
@@ -56,13 +57,11 @@ public class TestSearchEntities {

@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchDistance(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
client.insert(insertParam);
client.flush(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
SearchResult res_search = client.search(searchParam);
for (int i = 0; i < Constants.nq; i++) {
double distance = res_search.getResultDistancesList().get(i).get(0);
Assert.assertEquals(distance, 0.0, Constants.epsilon);
@@ -71,14 +70,12 @@ public class TestSearchEntities {

@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchDistanceIP(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getEntityIds();
InsertParam insertParam = Utils.genInsertParam(collectionName);
client.insert(insertParam);
client.flush(collectionName);
String dsl = Utils.setSearchParam("IP", Constants.vectors.subList(0, nq), top_k, n_probe);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchResponse res_search = client.search(searchParam);
String dsl = Utils.setSearchParam(MetricType.IP, queryVectors, top_k, n_probe);
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
SearchResult res_search = client.search(searchParam);
for (int i = 0; i < Constants.nq; i++) {
double distance = res_search.getResultDistancesList().get(i).get(0);
Assert.assertEquals(distance, 1.0, Constants.epsilon);
@@ -91,174 +88,359 @@ public class TestSearchEntities {
|
|||
List<String> queryTags = new ArrayList<>();
|
||||
queryTags.add(tag);
|
||||
client.createPartition(collectionName, tag);
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl).setPartitionTags(queryTags);
|
||||
SearchResult res_search = client.search(searchParam);
|
||||
Assert.assertEquals(res_search.getResultDistancesList().size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testSearchPartitionNotExited(MilvusClient client, String collectionName) {
|
||||
public void testSearchPartitionNotExisted(MilvusClient client, String collectionName) {
|
||||
String tag = Utils.genUniqueStr("tag");
|
||||
String tagNew = Utils.genUniqueStr("tagNew");
|
||||
List<String> queryTags = new ArrayList<>();
|
||||
queryTags.add(tag);
|
||||
queryTags.add(tagNew);
|
||||
client.createPartition(collectionName, tag);
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert (res.getResponse().ok());
|
||||
Map<String, List> entities = Constants.defaultEntities;
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
|
||||
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl).setPartitionTags(queryTags);
|
||||
SearchResult res_search = client.search(searchParam);
|
||||
Assert.assertEquals(res_search.getResultDistancesList().size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testSearchInvalidNProbe(MilvusClient client, String collectionName) {
|
||||
int n_probe_new = -1;
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = res.getEntityIds();
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
|
||||
Response res_create = client.createIndex(index);
|
||||
String dsl = Utils.setSearchParam(Constants.defaultMetricType, Constants.vectors.subList(0, nq), top_k, n_probe_new);
|
||||
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert(!res_search.getResponse().ok());
|
||||
Index index = Index
|
||||
.create(collectionName, Constants.floatVectorFieldName)
|
||||
.setIndexType(IndexType.IVF_SQ8)
|
||||
.setMetricType(MetricType.L2)
|
||||
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
|
||||
client.createIndex(index);
|
||||
String dsl = Utils.setSearchParam(Constants.defaultMetricType, queryVectors, top_k, n_probe_new);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
client.search(searchParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testSearchCountLessThanTopK(MilvusClient client, String collectionName) {
|
||||
int top_k_new = 100;
|
||||
int nb = 50;
|
||||
List<Map<String,Object>> entities = Utils.genDefaultEntities(Constants.dimension, nb, Utils.genVectors(nb, Constants.dimension, false));
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(entities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = res.getEntityIds();
|
||||
Map<String,List> entities = Utils.genDefaultEntities(nb, Utils.genVectors(nb, Constants.dimension, false));
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
|
||||
.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
String dsl = Utils.setSearchParam(Constants.defaultMetricType, Constants.vectors.subList(0, nq), top_k_new, n_probe);
|
||||
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert(res_search.getResponse().ok());
|
||||
String dsl = Utils.setSearchParam(Constants.defaultMetricType, queryVectors, top_k_new, n_probe);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
SearchResult res_search = client.search(searchParam);
|
||||
Assert.assertEquals(res_search.getResultIdsList().size(), Constants.nq);
|
||||
Assert.assertEquals(res_search.getResultDistancesList().size(), Constants.nq);
|
||||
Assert.assertEquals(res_search.getResultIdsList().get(0).size(), nb);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testSearchInvalidTopK(MilvusClient client, String collectionName) {
|
||||
int top_k = -1;
|
||||
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultEntities).build();
|
||||
InsertResponse res = client.insert(insertParam);
|
||||
assert(res.getResponse().ok());
|
||||
List<Long> ids = res.getEntityIds();
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
// Index index = new Index.Builder(collectionName, Constants.floatFieldName).withParamsInJson(Constants.indexParam).build();
|
||||
// Response res_create = client.createIndex(index);
|
||||
String dsl = Utils.setSearchParam(Constants.defaultMetricType, Constants.vectors.subList(0, nq), top_k, n_probe);
|
||||
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert(!res_search.getResponse().ok());
|
||||
String dsl = Utils.setSearchParam(Constants.defaultMetricType, queryVectors, top_k, n_probe);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
client.search(searchParam);
|
||||
}
|
||||
|
||||
// // Binary tests
|
||||
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// public void testSearchCollectionNotExistedBinary(MilvusClient client, String collectionName) {
|
||||
// String collectionNameNew = Utils.genUniqueStr(collectionName);
|
||||
// SearchParam searchParam = new SearchParam.Builder(collectionNameNew)
|
||||
// .withBinaryVectors(queryVectorsBinary)
|
||||
// .withParamsInJson(searchParamStr)
|
||||
// .withTopK(top_k).build();
|
||||
// SearchResponse res_search = client.search(searchParam);
|
||||
// assert (!res_search.getResponse().ok());
|
||||
// }
|
||||
//
|
||||
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// public void test_search_index_IVFLAT_binary(MilvusClient client, String collectionName) {
|
||||
// IndexType indexType = IndexType.IVFLAT;
|
||||
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
|
||||
// InsertResponse res = client.insert(insertParam);
|
||||
// client.flush(collectionName);
|
||||
// Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
|
||||
// client.createIndex(index);
|
||||
// SearchParam searchParam = new SearchParam.Builder(collectionName)
|
||||
// .withBinaryVectors(queryVectorsBinary)
|
||||
// .withParamsInJson(searchParamStr)
|
||||
// .withTopK(top_k).build();
|
||||
// List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
// Assert.assertEquals(res_search.size(), nq);
|
||||
// Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
// }
|
||||
//
|
||||
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// public void test_search_ids_IVFLAT_binary(MilvusClient client, String collectionName) {
|
||||
// IndexType indexType = IndexType.IVFLAT;
|
||||
// List<Long> vectorIds;
|
||||
// vectorIds = Stream.iterate(0L, n -> n)
|
||||
// .limit(nb)
|
||||
// .collect(Collectors.toList());
|
||||
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).withVectorIds(vectorIds).build();
|
||||
// InsertResponse res = client.insert(insertParam);
|
||||
// Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
|
||||
// client.createIndex(index);
|
||||
// SearchParam searchParam = new SearchParam.Builder(collectionName)
|
||||
// .withBinaryVectors(queryVectorsBinary)
|
||||
// .withParamsInJson(searchParamStr)
|
||||
// .withTopK(top_k).build();
|
||||
// List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
// Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L);
|
||||
// }
|
||||
//
|
||||
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// public void test_search_partition_not_existed_binary(MilvusClient client, String collectionName) {
|
||||
// IndexType indexType = IndexType.IVFLAT;
|
||||
// String tag = RandomStringUtils.randomAlphabetic(10);
|
||||
// Response createpResponse = client.createPartition(collectionName, tag);
|
||||
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
|
||||
// InsertResponse res = client.insert(insertParam);
|
||||
// String tagNew = RandomStringUtils.randomAlphabetic(10);
|
||||
// List<String> queryTags = new ArrayList<>();
|
||||
// queryTags.add(tagNew);
|
||||
// SearchParam searchParam = new SearchParam.Builder(collectionName)
|
||||
// .withBinaryVectors(queryVectorsBinary)
|
||||
// .withParamsInJson(searchParamStr)
|
||||
// .withPartitionTags(queryTags)
|
||||
// .withTopK(top_k).build();
|
||||
// SearchResponse res_search = client.search(searchParam);
|
||||
// assert (!res_search.getResponse().ok());
|
||||
// Assert.assertEquals(res_search.getQueryResultsList().size(), 0);
|
||||
// }
|
||||
//
|
||||
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// public void test_search_invalid_n_probe_binary(MilvusClient client, String collectionName) {
|
||||
// int n_probe_new = 0;
|
||||
// String searchParamStrNew = Utils.setSearchParam(n_probe_new);
|
||||
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
|
||||
// client.insert(insertParam);
|
||||
// SearchParam searchParam = new SearchParam.Builder(collectionName)
|
||||
// .withBinaryVectors(queryVectorsBinary)
|
||||
// .withParamsInJson(searchParamStrNew)
|
||||
// .withTopK(top_k).build();
|
||||
// SearchResponse res_search = client.search(searchParam);
|
||||
// assert (res_search.getResponse().ok());
|
||||
// }
|
||||
//
|
||||
// @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
// public void test_search_invalid_top_k_binary(MilvusClient client, String collectionName) {
|
||||
// int top_k_new = 0;
|
||||
// InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
|
||||
// client.insert(insertParam);
|
||||
// SearchParam searchParam = new SearchParam.Builder(collectionName)
|
||||
// .withBinaryVectors(queryVectorsBinary)
|
||||
// .withParamsInJson(searchParamStr)
|
||||
// .withTopK(top_k_new).build();
|
||||
// SearchResponse res_search = client.search(searchParam);
|
||||
// assert (!res_search.getResponse().ok());
|
||||
// }
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testSearchMultiMust(MilvusClient client, String collectionName){
|
||||
JSONObject vectorParam = Utils.genVectorParam(Constants.defaultMetricType, Constants.vectors.subList(0,nq), top_k, n_probe);
|
||||
JSONObject boolParam = new JSONObject();
|
||||
JSONObject mustParam = new JSONObject();
|
||||
JSONArray jsonArray = new JSONArray();
|
||||
jsonArray.add(vectorParam);
|
||||
JSONObject mustParam1 = new JSONObject();
|
||||
mustParam1.put("must", jsonArray);
|
||||
JSONArray jsonArray1 = new JSONArray();
|
||||
jsonArray1.add(mustParam1);
|
||||
mustParam.put("must", jsonArray1);
|
||||
boolParam.put("bool", mustParam);
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
String dsl = boolParam.toJSONString();
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
SearchResult resSearch = client.search(searchParam);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().size(), Constants.nq);
|
||||
Assert.assertEquals(resSearch.getResultDistancesList().size(), Constants.nq);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().get(0).size(), Constants.topk);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = InvalidDsl.class)
|
||||
public void testSearchMultiVectors(MilvusClient client, String collectionName) {
|
||||
String dsl = String.format(
|
||||
"{\"bool\": {"
|
||||
+ "\"must\": [{"
|
||||
+ " \"range\": {"
|
||||
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
|
||||
+ " }},{"
|
||||
+ " \"vector\": {"
|
||||
+ " \"float_vector\": {"
|
||||
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
|
||||
+ " }}},{"
|
||||
+ " \"vector\": {"
|
||||
+ " \"float_vector\": {"
|
||||
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}\"\n"
|
||||
+ " }}}]}}",
|
||||
top_k, queryVectors, top_k, queryVectors);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
client.search(searchParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = InvalidDsl.class)
|
||||
public void testSearchNoVectors(MilvusClient client, String collectionName) {
|
||||
String dsl = String.format(
|
||||
"{\"bool\": {"
|
||||
+ "\"must\": [{"
|
||||
+ " \"range\": {"
|
||||
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
|
||||
+ " }},{"
|
||||
+ " \"vector\": {"
|
||||
+ " \"float_vector\": {"
|
||||
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"params\": {\"nprobe\": 20}"
|
||||
+ " }}}]}}",
|
||||
top_k);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
client.search(searchParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testSearchVectorNotExisted(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
String dsl = String.format(
|
||||
"{\"bool\": {"
|
||||
+ "\"must\": [{"
|
||||
+ " \"range\": {"
|
||||
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
|
||||
+ " }},{"
|
||||
+ " \"vector\": {"
|
||||
+ " \"float_vector\": {"
|
||||
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
|
||||
+ " }}}]}}",
|
||||
top_k, new ArrayList<>());
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
SearchResult resSearch = client.search(searchParam);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().size(), 0);
|
||||
Assert.assertEquals(resSearch.getResultDistancesList().size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testSearchVectorDifferentDim(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
List<List<Float>> query = Utils.genVectors(nq,64, false);
|
||||
String dsl = String.format(
|
||||
"{\"bool\": {"
|
||||
+ "\"must\": [{"
|
||||
+ " \"range\": {"
|
||||
+ " \"int64\": {\"GT\": -10, \"LT\": 1000}"
|
||||
+ " }},{"
|
||||
+ " \"vector\": {"
|
||||
+ " \"float_vector\": {"
|
||||
+ " \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": %s, \"params\": {\"nprobe\": 20}"
|
||||
+ " }}}]}}",
|
||||
top_k, query);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
client.search(searchParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
|
||||
public void testAsyncSearch(MilvusClient client, String collectionName) {
|
||||
InsertParam insertParam = Utils.genInsertParam(collectionName);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(floatDsl);
|
||||
ListenableFuture<SearchResult> searchResFuture = client.searchAsync(searchParam);
|
||||
Futures.addCallback(
|
||||
searchResFuture, new FutureCallback<SearchResult>() {
|
||||
@Override
|
||||
public void onSuccess(SearchResult searchResult) {
|
||||
Assert.assertNotNull(searchResult);
|
||||
Assert.assertEquals(searchResult.getResultIdsList().size(), Constants.nq);
|
||||
Assert.assertEquals(searchResult.getResultDistancesList().size(), Constants.nq);
|
||||
Assert.assertEquals(searchResult.getResultIdsList().get(0).size(), Constants.topk);
|
||||
Assert.assertEquals(searchResult.getFieldsMap().get(0).size(), top_k);
|
||||
}
|
||||
@Override
|
||||
public void onFailure(Throwable t) {
|
||||
System.out.println(t.getMessage());
|
||||
}
|
||||
}, MoreExecutors.directExecutor()
|
||||
);
|
||||
}
|
||||
|
||||
// Binary tests
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
|
||||
public void testSearchCollectionNotExistedBinary(MilvusClient client, String collectionName) {
|
||||
String collectionNameNew = Utils.genUniqueStr(collectionName);
|
||||
SearchParam searchParam = SearchParam.create(collectionNameNew).setDsl(binaryDsl);
|
||||
client.search(searchParam);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testSearchCollectionBinary(MilvusClient client, String collectionName) {
|
||||
List<Long> intValues = new ArrayList<>(Constants.nb);
|
||||
List<Float> floatValues = new ArrayList<>(Constants.nb);
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
|
||||
for (int i = 0; i < Constants.nb; ++i) {
|
||||
intValues.add((long) i);
|
||||
floatValues.add((float) i);
|
||||
}
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, intValues)
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
|
||||
client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
List<String> vectorsToSearch = vectors.subList(0, Constants.nq)
|
||||
.stream().map(byteBuffer -> Arrays.toString(byteBuffer.array()))
|
||||
.collect(Collectors.toList());
|
||||
String dsl = String.format(
|
||||
"{\"bool\": {"
|
||||
+ "\"must\": [{"
|
||||
+ " \"vector\": {"
|
||||
+ " \"binary_vector\": {"
|
||||
+ " \"topk\": %d, \"metric_type\": \"JACCARD\", \"type\": \"binary\", \"query\": %s, \"params\": {\"nprobe\": 20}"
|
||||
+ " }}}]}}",
|
||||
top_k, vectorsToSearch.toString());
|
||||
SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
|
||||
SearchResult resSearch = client.search(searchParam);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().size(), nq);
|
||||
Assert.assertEquals(resSearch.getResultDistancesList().size(), nq);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
public void testSearchIVFLATBinary(MilvusClient client, String collectionName) {
|
||||
List<Long> intValues = new ArrayList<>(Constants.nb);
|
||||
List<Float> floatValues = new ArrayList<>(Constants.nb);
|
||||
List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
|
||||
for (int i = 0; i < Constants.nb; ++i) {
|
||||
intValues.add((long) i);
|
||||
floatValues.add((float) i);
|
||||
}
|
||||
InsertParam insertParam = InsertParam
|
||||
.create(collectionName)
|
||||
.addField(Constants.intFieldName, DataType.INT64, intValues)
|
||||
.addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
|
||||
.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
|
||||
List<Long> entityIds = client.insert(insertParam);
|
||||
client.flush(collectionName);
|
||||
Index index = Index
|
||||
.create(collectionName, Constants.binaryVectorFieldName)
|
||||
.setIndexType(IndexType.BIN_FLAT)
|
||||
.setMetricType(Constants.defaultBinaryMetricType)
|
||||
.setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
|
||||
client.createIndex(index);
|
||||
SearchParam searchParam = SearchParam
|
||||
.create(collectionName)
|
||||
.setDsl(Utils.setBinarySearchParam(Constants.defaultBinaryMetricType, vectors.subList(0, Constants.nq), Constants.topk, n_probe));
|
||||
SearchResult resSearch = client.search(searchParam);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().size(), nq);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().get(0).size(), top_k);
|
||||
Assert.assertEquals(resSearch.getResultIdsList().get(0).get(0), entityIds.get(0));
|
||||
}
|
||||
|
||||
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
|
||||
    public void testSearchPartitionNotExistedBinary(MilvusClient client, String collectionName) {
        String tag = Utils.genUniqueStr("tag");
        client.createPartition(collectionName, tag);
        List<Long> intValues = new ArrayList<>(Constants.nb);
        List<Float> floatValues = new ArrayList<>(Constants.nb);
        List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
        for (int i = 0; i < Constants.nb; ++i) {
            intValues.add((long) i);
            floatValues.add((float) i);
        }
        InsertParam insertParam = InsertParam
                .create(collectionName)
                .addField(Constants.intFieldName, DataType.INT64, intValues)
                .addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
                .addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
        client.insert(insertParam);
        client.flush(collectionName);
        String tagNew = Utils.genUniqueStr("tagNew");
        List<String> queryTags = new ArrayList<>();
        queryTags.add(tagNew);
        SearchParam searchParam = SearchParam.create(collectionName).setDsl(binaryDsl).setPartitionTags(queryTags);
        client.search(searchParam);
    }

    // #3656
    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testSearchInvalidNProbeBinary(MilvusClient client, String collectionName) {
        int n_probe_new = 0;
        List<Long> intValues = new ArrayList<>(Constants.nb);
        List<Float> floatValues = new ArrayList<>(Constants.nb);
        List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
        for (int i = 0; i < Constants.nb; ++i) {
            intValues.add((long) i);
            floatValues.add((float) i);
        }
        InsertParam insertParam = InsertParam
                .create(collectionName)
                .addField(Constants.intFieldName, DataType.INT64, intValues)
                .addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
                .addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
        client.insert(insertParam);
        client.flush(collectionName);
        Index index = Index
                .create(collectionName, Constants.binaryVectorFieldName)
                .setIndexType(Constants.defaultBinaryIndexType)
                .setMetricType(Constants.defaultBinaryMetricType)
                .setParamsInJson(new JsonBuilder().param("nlist", Constants.n_list).build());
        client.createIndex(index);
        String dsl = Utils.setBinarySearchParam(Constants.defaultBinaryMetricType, vectors.subList(0, Constants.nq), top_k, n_probe_new);
        SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
        client.search(searchParam);
    }

    @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class, expectedExceptions = ServerSideMilvusException.class)
    public void testSearchInvalidTopKBinary(MilvusClient client, String collectionName) {
        int top_k = -1;
        List<Long> intValues = new ArrayList<>(Constants.nb);
        List<Float> floatValues = new ArrayList<>(Constants.nb);
        List<ByteBuffer> vectors = Utils.genBinaryVectors(Constants.nb, Constants.dimension);
        for (int i = 0; i < Constants.nb; ++i) {
            intValues.add((long) i);
            floatValues.add((float) i);
        }
        InsertParam insertParam = InsertParam
                .create(collectionName)
                .addField(Constants.intFieldName, DataType.INT64, intValues)
                .addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
                .addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, vectors);
        client.insert(insertParam);
        client.flush(collectionName);
        String dsl = Utils.setBinarySearchParam(Constants.defaultBinaryMetricType, vectors.subList(0, Constants.nq), top_k, n_probe);
        SearchParam searchParam = SearchParam.create(collectionName).setDsl(dsl);
        client.search(searchParam);
    }
}

@@ -1,15 +1,15 @@
 package com;
 
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONArray;
 import io.milvus.client.*;
 import com.alibaba.fastjson.JSONObject;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.RandomStringUtils;
 
 import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.testng.Assert;
 
 public class Utils {
@@ -27,14 +27,14 @@ public class Utils {
     }
 
     public static List<List<Float>> genVectors(int vectorCount, int dimension, boolean norm) {
-        List<List<Float>> vectors = new ArrayList<>();
         Random random = new Random();
+        List<List<Float>> vectors = new ArrayList<>();
         for (int i = 0; i < vectorCount; ++i) {
             List<Float> vector = new ArrayList<>();
             for (int j = 0; j < dimension; ++j) {
                 vector.add(random.nextFloat());
             }
-            if (norm == true) {
+            if (norm) {
                 vector = normalize(vector);
             }
             vectors.add(vector);
@@ -42,12 +42,12 @@ public class Utils {
         return vectors;
     }
 
-    static List<ByteBuffer> genBinaryVectors(long vectorCount, long dimension) {
+    static List<ByteBuffer> genBinaryVectors(int vectorCount, int dimension) {
         Random random = new Random();
-        List<ByteBuffer> vectors = new ArrayList<>();
-        final long dimensionInByte = dimension / 8;
-        for (long i = 0; i < vectorCount; ++i) {
-            ByteBuffer byteBuffer = ByteBuffer.allocate((int) dimensionInByte);
+        List<ByteBuffer> vectors = new ArrayList<>(vectorCount);
+        final int dimensionInByte = dimension / 8;
+        for (int i = 0; i < vectorCount; ++i) {
+            ByteBuffer byteBuffer = ByteBuffer.allocate(dimensionInByte);
             random.nextBytes(byteBuffer.array());
             vectors.add(byteBuffer);
         }
@@ -57,10 +57,10 @@ public class Utils {
     private static List<Map<String, Object>> genBaseFieldsWithoutVector(){
         List<Map<String,Object>> fieldsList = new ArrayList<>();
         Map<String, Object> intFields = new HashMap<>();
-        intFields.put("field","int64");
+        intFields.put(Constants.fieldNameKey,Constants.intFieldName);
         intFields.put("type",DataType.INT64);
         Map<String, Object> floatField = new HashMap<>();
-        floatField.put("field","float");
+        floatField.put(Constants.fieldNameKey,Constants.floatFieldName);
         floatField.put("type",DataType.FLOAT);
         fieldsList.add(intFields);
         fieldsList.add(floatField);
@@ -72,10 +72,10 @@ public class Utils {
         List<Map<String, Object>> defaultFieldList = genBaseFieldsWithoutVector();
         Map<String, Object> vectorField = new HashMap<>();
         if (isBinary){
-            vectorField.put("field","binary_vector");
+            vectorField.put(Constants.fieldNameKey, Constants.binaryVectorFieldName);
             vectorField.put("type",DataType.VECTOR_BINARY);
         }else {
-            vectorField.put("field","float_vector");
+            vectorField.put(Constants.fieldNameKey, Constants.floatVectorFieldName);
             vectorField.put("type",DataType.VECTOR_FLOAT);
         }
         JSONObject jsonObject = new JSONObject();
@@ -86,55 +86,33 @@ public class Utils {
         return defaultFieldList;
     }
 
-    public static List<Map<String,Object>> genDefaultEntities(int dimension, int vectorCount, List<List<Float>> vectors){
-        List<Map<String,Object>> fieldsMap = genDefaultFields(dimension, false);
+    public static Map<String, List> genDefaultEntities(int vectorCount, List<List<Float>> vectors){
+        // Map<String,Object> fieldsMap = genDefaultFields(dimension, false);
+        Map<String, List> fieldsMap = new HashMap<>();
         List<Long> intValues = new ArrayList<>(vectorCount);
         List<Float> floatValues = new ArrayList<>(vectorCount);
         for (int i = 0; i < vectorCount; ++i) {
             intValues.add((long) i);
             floatValues.add((float) i);
         }
-        for(Map<String,Object> field: fieldsMap){
-            String fieldType = field.get("field").toString();
-            switch (fieldType){
-                case "int64":
-                    field.put("values",intValues);
-                    break;
-                case "float":
-                    field.put("values",floatValues);
-                    break;
-                case "float_vector":
-                    field.put("values",vectors);
-                    break;
-            }
-        }
+        fieldsMap.put(Constants.intFieldName,intValues);
+        fieldsMap.put(Constants.floatFieldName,floatValues);
+        fieldsMap.put(Constants.floatVectorFieldName,vectors);
         return fieldsMap;
     }
 
-    public static List<Map<String,Object>> genDefaultBinaryEntities(int dimension, int vectorCount, List<ByteBuffer> vectorsBinary){
-        List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
+    public static Map<String, List> genDefaultBinaryEntities(int vectorCount, List<ByteBuffer> vectorsBinary){
+        // List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
+        Map<String, List> binaryFieldsMap = new HashMap<>();
         List<Long> intValues = new ArrayList<>(vectorCount);
         List<Float> floatValues = new ArrayList<>(vectorCount);
+        // List<List<Float>> vectors = genVectors(vectorCount,dimension,false);
+        // List<ByteBuffer> binaryVectors = genBinaryVectors(vectorCount,dimension);
         for (int i = 0; i < vectorCount; ++i) {
             intValues.add((long) i);
             floatValues.add((float) i);
         }
-        for(Map<String,Object> field: binaryFieldsMap){
-            String fieldType = field.get("field").toString();
-            switch (fieldType){
-                case "int64":
-                    field.put("values",intValues);
-                    break;
-                case "float":
-                    field.put("values",floatValues);
-                    break;
-                case "binary_vector":
-                    field.put("values",vectorsBinary);
-                    break;
-            }
-        }
+        binaryFieldsMap.put(Constants.intFieldName,intValues);
+        binaryFieldsMap.put(Constants.floatFieldName,floatValues);
+        binaryFieldsMap.put(Constants.binaryVectorFieldName,vectorsBinary);
         return binaryFieldsMap;
     }
 
@@ -147,7 +125,7 @@ public class Utils {
         return indexParams;
     }
 
-    public static String setSearchParam(String metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
+    static JSONObject genVectorParam(MetricType metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
         JSONObject searchParam = new JSONObject();
         JSONObject fieldParam = new JSONObject();
         fieldParam.put("topk", topk);
@@ -158,34 +136,53 @@ public class Utils {
         tmpSearchParam.put("nprobe", nprobe);
         fieldParam.put("params", tmpSearchParam);
         JSONObject vectorParams = new JSONObject();
-        vectorParams.put(Constants.floatFieldName, fieldParam);
+        vectorParams.put(Constants.floatVectorFieldName, fieldParam);
         searchParam.put("vector", vectorParams);
-        JSONObject param = new JSONObject();
-        JSONObject mustParam = new JSONObject();
-        JSONArray tmp = new JSONArray();
-        tmp.add(searchParam);
-        mustParam.put("must", tmp);
-        param.put("bool", mustParam);
-        return JSONObject.toJSONString(param);
+        return searchParam;
     }
 
-    public static String setBinarySearchParam(String metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
+    static JSONObject genBinaryVectorParam(MetricType metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
         JSONObject searchParam = new JSONObject();
         JSONObject fieldParam = new JSONObject();
         fieldParam.put("topk", topk);
-        fieldParam.put("metricType", metricType);
-        fieldParam.put("queryVectors", queryVectors);
+        fieldParam.put("metric_type", metricType);
+        List<List<Byte>> vectorsToSearch = new ArrayList<>();
+        for (ByteBuffer byteBuffer : queryVectors) {
+            byte[] b = new byte[byteBuffer.remaining()];
+            byteBuffer.get(b);
+            vectorsToSearch.add(Arrays.asList(ArrayUtils.toObject(b)));
+        }
+        fieldParam.put("query", vectorsToSearch);
+        fieldParam.put("type", Constants.binaryVectorType);
         JSONObject tmpSearchParam = new JSONObject();
         tmpSearchParam.put("nprobe", nprobe);
         fieldParam.put("params", tmpSearchParam);
         JSONObject vectorParams = new JSONObject();
-        vectorParams.put(Constants.floatFieldName, fieldParam);
+        vectorParams.put(Constants.binaryVectorFieldName, fieldParam);
         searchParam.put("vector", vectorParams);
+        return searchParam;
     }
 
+    public static String setSearchParam(MetricType metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
+        JSONObject searchParam = genVectorParam(metricType, queryVectors, topk, nprobe);
+        JSONObject boolParam = new JSONObject();
+        JSONObject mustParam = new JSONObject();
-        mustParam.put("must", new JSONArray().add(searchParam));
+        JSONArray tmp = new JSONArray();
+        tmp.add(searchParam);
+        mustParam.put("must", tmp);
+        boolParam.put("bool", mustParam);
-        return JSONObject.toJSONString(searchParam);
+        return JSONObject.toJSONString(boolParam);
     }
 
+    public static String setBinarySearchParam(MetricType metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
+        JSONObject searchParam = genBinaryVectorParam(metricType, queryVectors, topk, nprobe);
+        JSONObject boolParam = new JSONObject();
+        JSONObject mustParam = new JSONObject();
+        JSONArray tmp = new JSONArray();
+        tmp.add(searchParam);
+        mustParam.put("must", tmp);
+        boolParam.put("bool", mustParam);
+        return JSONObject.toJSONString(boolParam);
+    }
 
     public static int getIndexParamValue(String indexParam, String key) {
@@ -218,7 +215,7 @@ public class Utils {
     public static List<Float> getVector(List<Map<String,Object>> entities, int i){
         List<Float> vector = new ArrayList<>();
         entities.forEach(entity -> {
-            if("float_vector".equals(entity.get("field")) && Objects.nonNull(entity.get("values"))){
+            if(Constants.floatVectorFieldName.equals(entity.get("field")) && Objects.nonNull(entity.get("values"))){
                 vector.add(((List<Float>)entity.get("values")).get(i));
             }
         });
@@ -239,4 +236,237 @@ public class Utils {
         throw new RuntimeException("unsupported type");
     }
 
+    public static InsertParam genInsertParam(String collectionName) {
+        Map<String, List> entities = Constants.defaultEntities;
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, entities.get(Constants.intFieldName))
+                .addField(Constants.floatFieldName, DataType.FLOAT, entities.get(Constants.floatFieldName))
+                .addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, entities.get(Constants.floatVectorFieldName));
+        return insertParam;
+    }
+
+    public static InsertParam genBinaryInsertParam(String collectionName) {
+        List<Long> intValues = new ArrayList<>(Constants.nb);
+        List<Float> floatValues = new ArrayList<>(Constants.nb);
+        for (int i = 0; i < Constants.nb; ++i) {
+            intValues.add((long) i);
+            floatValues.add((float) i);
+        }
+        InsertParam insertParam = InsertParam
+                .create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64, intValues)
+                .addField(Constants.floatFieldName, DataType.FLOAT, floatValues)
+                .addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, Utils.genBinaryVectors(Constants.nb, Constants.dimension));
+        return insertParam;
+    }
+
+    public static CollectionMapping genCreateCollectionMapping(String collectionName, Boolean autoId, Boolean isBinary) {
+        CollectionMapping cm = CollectionMapping.create(collectionName)
+                .addField(Constants.intFieldName, DataType.INT64)
+                .addField(Constants.floatFieldName, DataType.FLOAT)
+                .setParamsInJson(new JsonBuilder()
+                        .param("segment_row_limit", Constants.segmentRowLimit)
+                        .param("auto_id", autoId)
+                        .build());
+        if (isBinary) {
+            cm.addVectorField(Constants.binaryVectorFieldName, DataType.VECTOR_BINARY, Constants.dimension);
+        } else {
+            cm.addVectorField(Constants.floatVectorFieldName, DataType.VECTOR_FLOAT, Constants.dimension);
+        }
+        return cm;
+    }
+
+    public static List<Long> initData(MilvusClient client, String collectionName) {
+        InsertParam insertParam = Utils.genInsertParam(collectionName);
+        List<Long> ids = client.insert(insertParam);
+        client.flush(collectionName);
+        Assert.assertEquals(client.countEntities(collectionName), Constants.nb);
+        return ids;
+    }
+
+    public static List<Long> initBinaryData(MilvusClient client, String collectionName) {
+        InsertParam insertParam = Utils.genBinaryInsertParam(collectionName);
+        List<Long> ids = client.insert(insertParam);
+        client.flush(collectionName);
+        Assert.assertEquals(client.countEntities(collectionName), Constants.nb);
+        return ids;
+    }
+
+    ////////////////////////////////////////////////////////////////////////
+
+    // public static CollectionMapping genDefaultCollectionMapping(String collectionName, int dimension,
+    //                                                             int segmentRowCount, boolean isBinary) {
+    //     Map<String, Object> vectorFieldMap;
+    //     if (isBinary) {
+    //         vectorFieldMap = new FieldBuilder("binary_vector", DataType.VECTOR_BINARY)
+    //                 .param("dim", dimension)
+    //                 .build();
+    //     } else {
+    //         vectorFieldMap = new FieldBuilder("float_vector", DataType.VECTOR_FLOAT)
+    //                 .param("dim", dimension)
+    //                 .build();
+    //     }
+    //
+    //     return new CollectionMapping.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64).build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT).build())
+    //             .field(vectorFieldMap)
+    //             .withParamsInJson(new JsonBuilder()
+    //                     .param("segment_row_count", segmentRowCount)
+    //                     .build())
+    //             .build();
+    // }
+    //
+    // public static InsertParam genDefaultInsertParam(String collectionName, int dimension, int vectorCount,
+    //                                                 List<List<Float>> vectors) {
+    //     List<Long> intValues = new ArrayList<>(vectorCount);
+    //     List<Float> floatValues = new ArrayList<>(vectorCount);
+    //     for (int i = 0; i < vectorCount; ++i) {
+    //         intValues.add((long) i);
+    //         floatValues.add((float) i);
+    //     }
+    //
+    //     return new InsertParam.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64)
+    //                     .values(intValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT)
+    //                     .values(floatValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float_vector", DataType.VECTOR_FLOAT)
+    //                     .values(vectors)
+    //                     .param("dim", dimension)
+    //                     .build())
+    //             .build();
+    // }
+    //
+    // public static InsertParam genDefaultInsertParam(String collectionName, int dimension, int vectorCount,
+    //                                                 List<List<Float>> vectors, List<Long> entityIds) {
+    //     List<Long> intValues = new ArrayList<>(vectorCount);
+    //     List<Float> floatValues = new ArrayList<>(vectorCount);
+    //     for (int i = 0; i < vectorCount; ++i) {
+    //         intValues.add((long) i);
+    //         floatValues.add((float) i);
+    //     }
+    //
+    //     return new InsertParam.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64)
+    //                     .values(intValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT)
+    //                     .values(floatValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float_vector", DataType.VECTOR_FLOAT)
+    //                     .values(vectors)
+    //                     .param("dim", dimension)
+    //                     .build())
+    //             .withEntityIds(entityIds)
+    //             .build();
+    // }
+    //
+    // public static InsertParam genDefaultInsertParam(String collectionName, int dimension, int vectorCount,
+    //                                                 List<List<Float>> vectors, String tag) {
+    //     List<Long> intValues = new ArrayList<>(vectorCount);
+    //     List<Float> floatValues = new ArrayList<>(vectorCount);
+    //     for (int i = 0; i < vectorCount; ++i) {
+    //         intValues.add((long) i);
+    //         floatValues.add((float) i);
+    //     }
+    //
+    //     return new InsertParam.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64)
+    //                     .values(intValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT)
+    //                     .values(floatValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float_vector", DataType.VECTOR_FLOAT)
+    //                     .values(vectors)
+    //                     .param("dim", dimension)
+    //                     .build())
+    //             .withPartitionTag(tag)
+    //             .build();
+    // }
+    //
+    // public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount,
+    //                                                       List<List<Byte>> vectorsBinary) {
+    //     List<Long> intValues = new ArrayList<>(vectorCount);
+    //     List<Float> floatValues = new ArrayList<>(vectorCount);
+    //     for (int i = 0; i < vectorCount; ++i) {
+    //         intValues.add((long) i);
+    //         floatValues.add((float) i);
+    //     }
+    //
+    //     return new InsertParam.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64)
+    //                     .values(intValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT)
+    //                     .values(floatValues)
+    //                     .build())
+    //             .field(new FieldBuilder("binary_vector", DataType.VECTOR_BINARY)
+    //                     .values(vectorsBinary)
+    //                     .param("dim", dimension)
+    //                     .build())
+    //             .build();
+    // }
+    //
+    // public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount,
+    //                                                       List<List<Byte>> vectorsBinary, List<Long> entityIds) {
+    //     List<Long> intValues = new ArrayList<>(vectorCount);
+    //     List<Float> floatValues = new ArrayList<>(vectorCount);
+    //     for (int i = 0; i < vectorCount; ++i) {
+    //         intValues.add((long) i);
+    //         floatValues.add((float) i);
+    //     }
+    //
+    //     return new InsertParam.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64)
+    //                     .values(intValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT)
+    //                     .values(floatValues)
+    //                     .build())
+    //             .field(new FieldBuilder("binary_vector", DataType.VECTOR_BINARY)
+    //                     .values(vectorsBinary)
+    //                     .param("dim", dimension)
+    //                     .build())
+    //             .withEntityIds(entityIds)
+    //             .build();
+    // }
+    //
+    // public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount,
+    //                                                       List<List<Byte>> vectorsBinary, String tag) {
+    //     List<Long> intValues = new ArrayList<>(vectorCount);
+    //     List<Float> floatValues = new ArrayList<>(vectorCount);
+    //     for (int i = 0; i < vectorCount; ++i) {
+    //         intValues.add((long) i);
+    //         floatValues.add((float) i);
+    //     }
+    //
+    //     return new InsertParam.Builder(collectionName)
+    //             .field(new FieldBuilder("int64", DataType.INT64)
+    //                     .values(intValues)
+    //                     .build())
+    //             .field(new FieldBuilder("float", DataType.FLOAT)
+    //                     .values(floatValues)
+    //                     .build())
+    //             .field(new FieldBuilder("binary_vector", DataType.VECTOR_BINARY)
+    //                     .values(vectorsBinary)
+    //                     .param("dim", dimension)
+    //                     .build())
+    //             .withPartitionTag(tag)
+    //             .build();
+    // }
+    //
+    // public static Index genDefaultIndex(String collectionName, String fieldName, String indexType, String metricType, int nlist) {
+    //     return new Index.Builder(collectionName, fieldName)
+    //             .withParamsInJson(new JsonBuilder()
+    //                     .param("index_type", indexType)
+    //                     .param("metric_type", metricType)
+    //                     .indexParam("nlist", nlist)
+    //                     .build())
+    //             .build();
+    // }
 }

@@ -0,0 +1,14 @@
node_modules
npm-debug.log
Dockerfile*
docker-compose*
.dockerignore
.git
.gitignore
.env
*/bin
*/obj
README.md
LICENSE
.vscode
__pycache__

@@ -0,0 +1,13 @@
.python-version
.pytest_cache
__pycache__
.vscode
.idea

test_out/
*.pyc

db/
logs/

.coverage

@@ -0,0 +1,15 @@
FROM python:3.6.8-jessie

LABEL Name=megasearch_engine_test Version=0.0.1

WORKDIR /app
COPY . /app

RUN apt-get update && apt-get install -y --no-install-recommends \
    libc-dev build-essential && \
    python3 -m pip install -r requirements.txt && \
    apt-get remove --purge -y && \
    rm -rf /var/lib/apt/lists/*

ENTRYPOINT [ "/app/docker-entrypoint.sh" ]
CMD [ "start" ]

@@ -0,0 +1,62 @@
## Requirements
* python 3.6.8+
* pip install -r requirements.txt

## How to Build Test Env
```shell
sudo docker pull registry.zilliz.com/milvus/milvus-test:v0.2
sudo docker run -it -v /home/zilliz:/home/zilliz -d registry.zilliz.com/milvus/milvus-test:v0.2
```

## How to Create Test Env Docker in k8s
```shell
# 1. start milvus k8s pod
cd milvus-helm/charts/milvus
helm install --wait --timeout 300s \
    --set image.repository=registry.zilliz.com/milvus/engine \
    --set persistence.enabled=true \
    --set image.tag=PR-3818-gpu-centos7-release \
    --set image.pullPolicy=Always \
    --set service.type=LoadBalancer \
    -f ci/db_backend/mysql_gpu_values.yaml \
    -f ci/filebeat/values.yaml \
    -f test.yaml \
    --namespace milvus \
    milvus-ci-pr-3818-1-single-centos7-gpu .

# 2. remove milvus k8s pod
helm uninstall -n milvus milvus-test

# 3. check k8s pod status
kubectl get svc -n milvus -w milvus-test

# 4. login to pod
kubectl get pods --namespace milvus
kubectl exec -it milvus-test-writable-6cc49cfcd4-rbrns -n milvus bash
```

## How to Run Test Cases
```shell
# Test level-1 cases
pytest . --level=1 --ip=127.0.0.1 --port=19530

# Test level-1 cases in 'test_connect.py' only
pytest test_connect.py --level=1
```

## How to List Test Cases
```shell
# List all cases
pytest --dry-run -qq

# Collect all cases with docstring
pytest --collect-only -qq

# Create test report with allure
pytest --alluredir=test_out . -q -v
allure serve test_out
```

## Contribution Getting Started
* Follow PEP-8 for naming and black for formatting.
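
## Test Client
The HTTP cases drive the server through the small REST wrapper in `client.py`. A minimal standalone sketch, assuming a server reachable on the suite's default HTTP port (`conftest.py` defaults to 19121; the address below is illustrative):

```python
from client import MilvusClient

# the wrapper expects a base URL with a trailing slash
client = MilvusClient("http://127.0.0.1:19121/")
print(client.list_collections())
```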

@@ -0,0 +1,323 @@
import logging
import pdb
import json
import requests
import traceback
import utils
from milvus import Milvus
from utils import *

url_collections = "collections"
url_system = "system/"


class Request(object):
    def __init__(self, url):
        # logging.getLogger().error(url)
        self._url = url

    def _check_status(self, result):
        # logging.getLogger().info(result.text)
        if result.status_code not in [200, 201, 204]:
            return False
        if not result.text or "code" not in json.loads(result.text):
            return True
        elif json.loads(result.text)["code"] == 0:
            return True
        else:
            logging.getLogger().error(result.status_code)
            logging.getLogger().error(result.reason)
            return False

    def get(self, data=None):
        res_get = requests.get(self._url, params=data)
        return self._check_status(res_get), json.loads(res_get.text)["data"]

    def get_with_body(self, data=None):
        res_get = requests.get(self._url, data=json.dumps(data))
        return self._check_status(res_get), json.loads(res_get.text)["data"]

    def post(self, data):
        res_post = requests.post(self._url, data=json.dumps(data))
        if res_post.text:
            return self._check_status(res_post), json.loads(res_post.text)
        else:
            return self._check_status(res_post), res_post

    def delete(self, data=None):
        if data:
            res_delete = requests.delete(self._url, data=json.dumps(data))
        else:
            res_delete = requests.delete(self._url)
        return self._check_status(res_delete), res_delete

    def put(self, data=None):
        if data:
            res_put = requests.put(self._url, data=json.dumps(data))
        else:
            res_put = requests.put(self._url)
        return self._check_status(res_put), res_put
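
    # Note: every verb above returns a (status, payload) pair. `status` is True
    # when the HTTP code is 2xx and the body either has no "code" field or
    # carries "code" == 0; `payload` is the parsed JSON body (or the raw
    # response object where no JSON body is expected).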


class MilvusClient(object):
    def __init__(self, url):
        logging.getLogger().debug(url)
        self._url = url

    def create_collection(self, collection_name, fields):
        url = self._url + url_collections
        r = Request(url)
        fields.update({"collection_name": collection_name})
        try:
            status, data = r.post(fields)
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def list_collections(self, offset=0, page_size=10):
        url = self._url + url_collections + '?' + 'offset=' + str(offset) + '&page_size=' + str(page_size)
        r = Request(url)
        try:
            status, data = r.get()
            if status:
                return data["collections"]
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def has_collection(self, collection_name):
        url = self._url + url_collections + '/' + collection_name
        r = Request(url)
        try:
            status, data = r.get()
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def drop_collection(self, collection_name):
        url = self._url + url_collections + '/' + str(collection_name)
        r = Request(url)
        try:
            status, data = r.delete()
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def info_collection(self, collection_name):
        url = self._url + url_collections + '/' + str(collection_name)
        r = Request(url)
        try:
            status, data = r.get()
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def stat_collection(self, collection_name):
        url = self._url + url_collections + '/' + str(collection_name)
        r = Request(url)
        try:
            status, data = r.get(data={"info": "stat"})
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def count_collection(self, collection_name):
        return self.stat_collection(collection_name)["row_count"]

    def create_partition(self, collection_name, tag):
        url = self._url + url_collections + '/' + collection_name + '/partitions'
        r = Request(url)
        create_params = {"partition_tag": tag}
        try:
            status, data = r.post(create_params)
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def list_partitions(self, collection_name):
        url = self._url + url_collections + '/' + collection_name + '/partitions'
        r = Request(url)
        try:
            status, data = r.get()
            if status:
                return data["partitions"]
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def drop_partition(self, collection_name, tag):
        url = self._url + url_collections + '/' + collection_name + '/partitions/' + tag
        r = Request(url)
        try:
            status, data = r.delete()
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def flush(self, collection_names):
        url = self._url + url_system + 'task'
        r = Request(url)
        flush_params = {
            "flush": {"collection_names": collection_names}}
        try:
            status, data = r.put(data=flush_params)
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def insert(self, collection_name, entities, tag=None):
        url = self._url + url_collections + '/' + collection_name + '/entities'
        r = Request(url)
        insert_params = {"entities": entities}
        if tag:
            insert_params.update({"partition_tag": tag})
        try:
            status, data = r.post(insert_params)
            if status:
                return data["ids"]
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def delete(self, collection_name, ids):
        url = self._url + url_collections + '/' + collection_name + '/entities'
        r = Request(url)
        delete_params = {"ids": ids}
        try:
            status, data = r.delete(data=delete_params)
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    '''
    method: get entities by ids
    '''
    def get_entities(self, collection_name, ids):
        ids = ','.join(str(i) for i in ids)
        url = self._url + url_collections + '/' + collection_name + '/entities?ids=' + ids
        # url = self._url+url_collections+'/'+collection_name+'/entities'
        r = Request(url)
        try:
            status, data = r.get()
            if status:
                return data["entities"]
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    '''
    method: create index
    '''
    def create_index(self, collection_name, field_name, index_params):
        url = self._url + url_collections + '/' + collection_name + '/fields/' + field_name + '/indexes'
        r = Request(url)
        try:
            status, data = r.post(index_params)
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def drop_index(self, collection_name, field_name):
        url = self._url + url_collections + '/' + collection_name + '/fields/' + field_name + '/indexes'
        r = Request(url)
        try:
            status, data = r.delete()
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

    def describe_index(self, collection_name, field_name):
        info = self.info_collection(collection_name)
        for field in info["fields"]:
            if field["field_name"] == field_name:
                return field["index_params"]

    def search(self, collection_name, query_expr, fields=None):
        url = self._url + url_collections + '/' + str(collection_name) + '/entities'
        r = Request(url)
        search_params = {
            "query": query_expr,
            "fields": fields
        }
        try:
            status, data = r.get_with_body(search_params)
            if status:
                return data
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False
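
    # The `query_expr` passed to search() is a JSON DSL of the same shape the
    # Java Utils helpers above build; as a sketch (values illustrative, key set
    # inferred from genBinaryVectorParam and not verified against the REST spec):
    #   {"bool": {"must": [{"vector": {"binary_vector": {
    #       "topk": 10, "metric_type": "JACCARD", "query": [...],
    #       "params": {"nprobe": 16}}}}]}}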

    '''
    method: drop all collections in db
    '''
    def clear_db(self):
        collections = self.list_collections(page_size=10000)
        if collections:
            for item in collections:
                self.drop_collection(item["collection_name"])

    def system_cmd(self, cmd):
        url = self._url + url_system + cmd
        r = Request(url)
        try:
            status, data = r.get()
            if status:
                # the system endpoint wraps its answer in a "reply" field
                return data["reply"]
            else:
                return False
        except Exception as e:
            logging.getLogger().error(str(e))
            return False

@@ -0,0 +1,88 @@
import pdb
import copy
import logging
import itertools
from time import sleep
import threading
from multiprocessing import Process
import sklearn.preprocessing

import pytest
from utils import *
from constants import *

uid = "create_collection"


class TestCreateCollection:
    """
    ******************************************************************
    The following cases are used to test `create_collection` function
    ******************************************************************
    """
    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_segment_row_limits()
    )
    def get_segment_row_limit(self, request):
        yield request.param

    def test_create_collection_segment_row_limit(self, client, get_segment_row_limit):
        '''
        target: test create normal collection with different fields
        method: create collection with different segment_row_limit values
        expected: no exception raised
        '''
        collection_name = gen_unique_str(uid)
        fields = copy.deepcopy(default_fields)
        fields["segment_row_limit"] = get_segment_row_limit
        assert client.create_collection(collection_name, fields)
        assert client.has_collection(collection_name)

    def test_create_collection_exceed_segment_row_limit(self, client):
        '''
        target: test create collection with a segment_row_limit beyond the allowed range
        method: create collection with an oversized segment_row_limit
        expected: create failed, collection is not created
        '''
        segment_row_limit = 10000000
        collection_name = gen_unique_str(uid)
        fields = copy.deepcopy(default_fields)
        fields["segment_row_limit"] = segment_row_limit
        client.create_collection(collection_name, fields)
        assert not client.has_collection(collection_name)

    def test_create_collection_id(self, client):
        '''
        target: test create id collection
        method: create collection with auto_id false
        expected: no exception raised
        '''
        collection_name = gen_unique_str(uid)
        fields = copy.deepcopy(default_fields)
        fields["auto_id"] = False
        # fields = gen_default_fields(auto_id=False)
        client.create_collection(collection_name, fields)
        assert client.has_collection(collection_name)

    def _test_create_binary_collection(self, client):
        collection_name = 'test_NRHgct0s'
        fields = {'fields': [{'name': 'int64', 'type': 'INT64'},
                             {'name': 'float', 'type': 'FLOAT'},
                             {'name': 'binary_vector', 'type': 'BINARY_FLOAT', 'params': {'dim': 128}}],
                  'segment_row_limit': 1000, 'auto_id': True}
        client.create_collection(collection_name, fields)
        assert client.has_collection(collection_name)
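
For reference, a sketch of the payload `create_collection` posts, inferred from `_test_create_binary_collection` above and the `fields[...]` tweaks in the other cases; the float-vector entry and the dim/limit values are assumptions:

```python
# assumed shape of gen_default_fields() output used by these cases
default_fields = {
    "fields": [
        {"name": "int64", "type": "INT64"},
        {"name": "float", "type": "FLOAT"},
        {"name": "float_vector", "type": "VECTOR_FLOAT", "params": {"dim": 128}},
    ],
    "segment_row_limit": 4096,  # illustrative
    "auto_id": True,
}
```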

@@ -0,0 +1,55 @@
import pytest
import time
from utils import *
from constants import *


class TestDropCollection:
    """
    ******************************************************************
    The following cases are used to test `drop_collection` function
    ******************************************************************
    """
    def test_drop_collection(self, client, collection):
        '''
        target: test delete collection created with correct params
        method: create collection and then drop it,
            assert the value returned by the drop_collection method
        expected: status ok, and no collection in collections
        '''
        client.drop_collection(collection)
        time.sleep(2)
        assert not client.has_collection(collection)

    def test_drop_collection_not_existed(self, client):
        '''
        target: test drop collection which is not created
        method: use a random collection name which does not exist in db,
            assert the value returned by the drop_collection method
        expected: False
        '''
        collection_name = gen_unique_str()
        assert not client.drop_collection(collection_name)

    @pytest.fixture(
        scope="function",
        params=[
            1,
            "12-s",
            " ",
            "12 s",
            " siede ",
            "(mn)",
            "中文",
            "a".join("a" for i in range(256))
        ]
    )
    def get_invalid_collection_name(self, request):
        yield request.param

    def test_drop_collection_with_invalid_collection(self, client, get_invalid_collection_name):
        '''
        target: test drop collection with an invalid collection name
        method: call drop_collection with an invalid name
        expected: False
        '''
        assert not client.drop_collection(get_invalid_collection_name)

@@ -0,0 +1,119 @@
import logging

import pytest
import time
import copy
from utils import *
from constants import *

uid = "info_collection"


class TestInfoBase:
    """
    ******************************************************************
    The following cases are used to test `get_collection_info` function, no data in collection
    ******************************************************************
    """

    def test_get_collection_info(self, client, collection):
        """
        target: test get collection info with normal collection
        method: create collection with default fields and get collection info
        expected: no exception raised, and values returned are correct
        """
        info = client.info_collection(collection)
        assert info['count'] == 0
        assert info['auto_id'] == True
        assert info['segment_row_limit'] == default_segment_row_limit
        assert len(info["fields"]) == 3
        for field in info['fields']:
            if field['type'] == 'INT64':
                assert field['name'] == default_int_field_name
            if field['type'] == 'FLOAT':
                assert field['name'] == default_float_field_name
            if field['type'] == 'VECTOR_FLOAT':
                assert field['name'] == default_float_vec_field_name

    def test_get_collection_info_segment_row_limit(self, client, collection):
        """
        target: test get collection info with non-default segment row limit
        method: create collection with non-default segment row limit and get collection info
        expected: no exception raised
        """
        segment_row_limit = 4096
        collection_name = gen_unique_str(uid)
        fields = copy.deepcopy(default_fields)
        fields["segment_row_limit"] = segment_row_limit
        client.create_collection(collection_name, fields)
        client.insert(collection_name, default_entities)
        client.flush([collection_name])
        info = client.info_collection(collection_name)
        assert info['segment_row_limit'] == segment_row_limit

    def test_get_collection_info_id_collection(self, client, id_collection):
        """
        target: test get collection info with id collection
        method: create id collection with auto_id=False and get collection info
        expected: no exception raised
        """
        info = client.info_collection(id_collection)
        assert info['count'] == 0
        assert info['auto_id'] == False
        assert info['segment_row_limit'] == default_segment_row_limit
        assert len(info["fields"]) == 3

    def test_get_collection_info_with_collection_not_existed(self, client):
        """
        target: test get collection info with a collection which does not exist
        method: call get collection info with a random collection name which is not in db
        expected: not ok
        """
        collection_name = gen_unique_str(uid)
        assert not client.info_collection(collection_name)

    @pytest.fixture(
        scope="function",
        params=[
            1,
            "12-s",
            " ",
            "12 s",
            " siede ",
            "(mn)",
            "中文",
            "a".join("a" for i in range(256))
        ]
    )
    def get_invalid_collection_name(self, request):
        yield request.param

    def test_get_collection_info_collection_name_invalid(self, client, get_invalid_collection_name):
        collection_name = get_invalid_collection_name
        assert not client.info_collection(collection_name)

    def test_row_count_after_insert(self, client, collection):
        """
        target: test the change of collection row count after insert data
        method: insert entities to collection and get collection info
        expected: row count increases
        """
        info = client.info_collection(collection)
        assert info['count'] == 0
        assert client.insert(collection, default_entities)
        client.flush([collection])
        info = client.info_collection(collection)
        assert info['count'] == default_nb

    def test_get_collection_info_after_index_created(self, client, collection):
        """
        target: test index of collection info after index created
        method: create index and get collection info
        expected: no exception raised
        """
        index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
        res = client.create_index(collection, default_float_vec_field_name, index)
        info = client.info_collection(collection)
        for field in info['fields']:
            if field['name'] == default_float_vec_field_name:
                assert field['index_params'] == index
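
Taken together, the assertions above pin down the expected `info_collection` payload; a sketch written in terms of the suite's own constants (the scalar fields are assumed to carry empty `index_params`, which these cases do not check):

```python
# sketch of the GET /collections/<name> response body checked above
info = {
    "count": 0,
    "auto_id": True,
    "segment_row_limit": default_segment_row_limit,
    "fields": [
        {"name": default_int_field_name, "type": "INT64"},
        {"name": default_float_field_name, "type": "FLOAT"},
        {"name": default_float_vec_field_name, "type": "VECTOR_FLOAT",
         "index_params": {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}},
    ],
}
```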

@@ -0,0 +1,76 @@
import logging

import pytest
import time
import copy
from utils import *
from constants import *

uid = "stats_collection"


class TestStatsBase:
    """
    ******************************************************************
    The following cases are used to test `collection_stats` function
    ******************************************************************
    """

    def test_get_collection_stats(self, client, collection):
        """
        target: get collection stats
        method: call stat_collection with a created collection
        expected: status ok
        """
        client.insert(collection, default_entities)
        client.flush([collection])
        stats = client.stat_collection(collection)
        assert stats['row_count'] == default_nb
        assert len(stats["partitions"]) == 1
        assert stats["partitions"][0]["tag"] == default_partition_name
        assert stats["partitions"][0]["row_count"] == default_nb

    def test_get_collection_stats_collection_not_existed(self, client, collection):
        """
        target: get collection stats when the collection does not exist
        method: call stat_collection with a random collection_name which is not in db
        expected: status not ok
        """
        collection_name = gen_unique_str(uid)
        assert not client.stat_collection(collection_name)

    @pytest.fixture(
        scope="function",
        params=[
            1,
            "12-s",
            " ",
            "12 s",
            " siede ",
            "(mn)",
            "中文",
            "a".join("a" for i in range(256))
        ]
    )
    def get_invalid_collection_name(self, request):
        yield request.param

    def test_get_collection_stats_name_invalid(self, client, collection, get_invalid_collection_name):
        """
        target: get collection stats where collection name is invalid
        method: call stat_collection with an invalid collection_name
        expected: status not ok
        """
        assert not client.stat_collection(get_invalid_collection_name)

    def test_get_collection_stats_empty_collection(self, client, collection):
        """
        target: get collection stats when no entity is in the collection
        method: call stat_collection with an empty collection
        expected: row count is 0 and only the default partition is listed
        """
        stats = client.stat_collection(collection)
        assert stats["row_count"] == 0
        assert len(stats["partitions"]) == 1
        assert stats["partitions"][0]["tag"] == default_partition_name
        assert stats["partitions"][0]["row_count"] == 0

@@ -0,0 +1,91 @@
import logging

import pytest
import time
import copy
from utils import *
from constants import *

uid = "list_collection"


class TestListCollections:
    """
    ******************************************************************
    The following cases are used to test `list_collections` function
    ******************************************************************
    """

    def test_list_collections(self, client, collection):
        '''
        target: test list collections
        method: create collection, assert the value returned by list_collections method
        expected: the collection name is in the list
        '''
        collections = map(lambda x: x['collection_name'], client.list_collections())
        assert collection in collections

    def test_list_collections_not_existed(self, client):
        '''
        target: test list collections when the collection is not created
        method: use a random collection name which does not exist in db, assert the value returned by list_collections method
        expected: the name is not in the list
        '''
        collection_name = gen_unique_str(uid)
        collections = map(lambda x: x['collection_name'], client.list_collections())
        assert collection_name not in collections

    def test_list_collections_no_collection(self, client):
        '''
        target: test list collections when no collection is in db
        method: delete all collections and list collections
        expected: status is ok and the length of the result is 0
        '''
        client.clear_db()
        result = client.list_collections()
        assert len(result) == 0

    def test_list_collections_multi_collections(self, client, collection):
        '''
        target: test list collections with multi collections
        method: create multi collections and list them
        expected: all created collections are listed
        '''
        collection_name = gen_unique_str(uid)
        fields = copy.deepcopy(default_fields)
        assert client.create_collection(collection_name, fields)
        collections = list(map(lambda x: x['collection_name'], client.list_collections()))
        assert collection_name in collections
        assert collection in collections

    def test_list_collections_offset(self, client, collection):
        '''
        target: test list collections with offset parameter
        method: create multi collections and list them with offset
        expected: first collection with offset=1 equals the second collection with offset=0
        '''
        collection_num = 2
        fields = copy.deepcopy(default_fields)
        for i in range(collection_num):
            collection_name = gen_unique_str(uid)
            assert client.create_collection(collection_name, fields)
        collections = list(map(lambda x: x['collection_name'], client.list_collections()))
        collections_new = list(map(lambda x: x['collection_name'], client.list_collections(offset=1)))
        assert collections[1] == collections_new[0]

    def test_list_collections_page_size(self, client, collection):
        '''
        target: test list collections with page_size parameter
        method: create multi collections and list them with page_size
        expected: number of collections returned equals page_size
        '''
        collection_num = 6
        page_size = 5
        fields = copy.deepcopy(default_fields)
        for i in range(collection_num):
            collection_name = gen_unique_str(uid)
            assert client.create_collection(collection_name, fields)
        collections = list(map(lambda x: x['collection_name'], client.list_collections(page_size=page_size)))
        assert len(collections) == page_size

@@ -0,0 +1,140 @@
import pdb
import logging
import socket
import pytest
import requests
from utils import gen_unique_str
from client import MilvusClient
from utils import *

timeout = 60
dimension = 128
delete_timeout = 60


def pytest_addoption(parser):
    parser.addoption("--ip", action="store", default="localhost")
    # parser.addoption("--ip", action="store", default="192.168.1.113")
    parser.addoption("--service", action="store", default="")
    parser.addoption("--port", action="store", default=19121)
    parser.addoption("--tag", action="store", default="all", help="only run tests matching the tag.")
    parser.addoption('--dry-run', action='store_true', default=False)


def pytest_configure(config):
    # register an additional marker
    config.addinivalue_line(
        "markers", "tag(name): mark test to run only matching the tag"
    )


def pytest_runtest_setup(item):
    tags = list()
    for marker in item.iter_markers(name="tag"):
        for tag in marker.args:
            tags.append(tag)
    if tags:
        cmd_tag = item.config.getoption("--tag")
        if cmd_tag != "all" and cmd_tag not in tags:
            pytest.skip("test requires tag in {!r}".format(tags))


def pytest_runtestloop(session):
    if session.config.getoption('--dry-run'):
        for item in session.items:
            print(item.nodeid)
        return True


def check_server_connection(request):
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")

    connected = True
    if ip and (ip not in ['localhost', '127.0.0.1']):
        try:
            socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
        except Exception as e:
            print("Socket connect failed: %s" % str(e))
            connected = False
    return connected


@pytest.fixture(scope="module")
def args(request):
    ip = request.config.getoption("--ip")
    service_name = request.config.getoption("--service")
    port = request.config.getoption("--port")
    url = "http://%s:%s/" % (ip, port)
    client = MilvusClient(url)
    args = {"ip": ip, "port": port, "service_name": service_name, "url": url, "client": client}
    return args


@pytest.fixture(scope="function")
def client(request, args):
    client = args["client"]
    return client


@pytest.fixture(scope="function")
def collection(request, client):
    ori_collection_name = getattr(request.module, "collection_id", "test")
    collection_name = gen_unique_str(ori_collection_name)
    create_params = gen_default_fields()
    try:
        if not client.create_collection(collection_name, create_params):
            pytest.exit("create collection %s failed" % collection_name)
    except Exception as e:
        pytest.exit(str(e))
    def teardown():
        client.clear_db()
    request.addfinalizer(teardown)
    assert client.has_collection(collection_name)
    return collection_name


# customised id
@pytest.fixture(scope="function")
def id_collection(request, client):
    ori_collection_name = getattr(request.module, "collection_id", "test")
    collection_name = gen_unique_str(ori_collection_name)
    try:
        client.create_collection(collection_name, gen_default_fields(auto_id=False))
    except Exception as e:
        pytest.exit(str(e))
    def teardown():
        client.clear_db()
    request.addfinalizer(teardown)
    assert client.has_collection(collection_name)
    return collection_name


@pytest.fixture(scope="function")
def binary_collection(request, client):
    ori_collection_name = getattr(request.module, "collection_id", "test")
    collection_name = gen_unique_str(ori_collection_name)
    try:
        client.create_collection(collection_name, gen_default_fields(binary=True))
    except Exception as e:
        pytest.exit(str(e))
    def teardown():
        client.clear_db()
    request.addfinalizer(teardown)
    assert client.has_collection(collection_name)
    return collection_name


# customised id
@pytest.fixture(scope="function")
def binary_id_collection(request, client):
    ori_collection_name = getattr(request.module, "collection_id", "test")
    collection_name = gen_unique_str(ori_collection_name)
    try:
        client.create_collection(collection_name, gen_default_fields(auto_id=False, binary=True))
    except Exception as e:
        pytest.exit(str(e))
    def teardown():
        client.clear_db()
    request.addfinalizer(teardown)
    assert client.has_collection(collection_name)
    return collection_name
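
A minimal sketch of how these fixtures combine in a test module (a hypothetical test following the suite's conventions, not part of the suite):

```python
collection_id = "example"  # picked up by the fixtures via request.module

def test_example(client, collection):
    # `collection` is the name of a fresh collection built from gen_default_fields();
    # the fixture's finalizer clears the db after the test
    assert client.has_collection(collection)
```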

@@ -0,0 +1,10 @@
import utils

default_fields = utils.gen_default_fields()
default_binary_fields = utils.gen_default_fields(binary=True)

default_entity = utils.gen_entities(1)
default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)

default_entities = utils.gen_entities(utils.default_nb)
default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)

@@ -0,0 +1,9 @@
#!/bin/bash

set -e

if [ "$1" = 'start' ]; then
    tail -f /dev/null
fi

exec "$@"

@@ -0,0 +1,89 @@
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *

DELETE_TIMEOUT = 60
uid = "test_delete"


class TestDeleteBase:
    """
    ******************************************************************
    The following cases are used to test `delete` function
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, client):
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("CPU mode does not support index_type: ivf_sq8h")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    def test_delete_entity_id_not_existed(self, client, collection):
        '''
        target: test get entity, params entity_id not existed
        method: insert an entity and get by another id
        expected: result empty
        '''
        assert client.insert(collection, default_entity)
        res_flush = client.flush([collection])
        entities = client.get_entities(collection, [1])
        assert entities

    def test_delete_empty_collection(self, client, collection):
        '''
        target: test delete entity from an empty collection
        method: delete from a freshly created collection
        expected: status ok
        '''
        status = client.delete(collection, ["0"])
        assert status

    def test_delete_entity_collection_not_existed(self, client, collection):
        '''
        target: test delete entity, params collection_name not existed
        method: delete from a collection name which is not in db
        expected: status not ok
        '''
        collection_new = gen_unique_str()
        status = client.delete(collection_new, ["0"])
        assert not status

    def test_insert_delete(self, client, collection):
        '''
        target: test delete entity
        method: add entities and delete one of them
        expected: no error raised, and row count decreases by one
        '''
        ids = client.insert(collection, default_entities)
        client.flush([collection])
        delete_ids = [ids[0]]
        status = client.delete(collection, delete_ids)
        assert status
        client.flush([collection])
        res_count = client.count_collection(collection)
        assert res_count == default_nb - 1
@@ -0,0 +1,99 @@
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *

GET_TIMEOUT = 120
uid = "test_get"


class TestGetBase:
    """
    ******************************************************************
    The following cases are used to test `get_entities` function
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, client):
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("CPU does not support index_type: ivf_sq8h")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    def test_get_entity_id_not_existed(self, client, collection):
        '''
        target: test get entity, params entity_id not existed
        method: add entity and get with ids that may not exist
        expected: status ok
        '''
        ids = client.insert(collection, default_entity)
        client.flush([collection])
        entities = client.get_entities(collection, [0, 1])
        assert entities

    def test_get_empty_collection(self, client, collection):
        '''
        target: test get entity from a collection with no entities
        method: get an id without inserting anything
        expected: status ok
        '''
        entities = client.get_entities(collection, [0])
        assert entities

    def test_get_entity_collection_not_existed(self, client, collection):
        '''
        target: test get entity, params collection_name not existed
        method: get from a collection that does not exist
        expected: error code returned
        '''
        collection_new = gen_unique_str()
        entities = client.get_entities(collection_new, [0])
        assert not entities

    def test_insert_get(self, client, collection):
        '''
        target: test get entity
        method: add entities and get one of them
        expected: entity returned
        '''
        ids = client.insert(collection, default_entities)
        client.flush([collection])
        get_ids = [ids[0]]
        entities = client.get_entities(collection, get_ids)
        assert len(entities) == 1

    def test_insert_get_batch(self, client, collection):
        '''
        target: test get entities in batch
        method: add entities and get the first ten
        expected: entities returned
        '''
        get_length = 10
        ids = client.insert(collection, default_entities)
        client.flush([collection])
        get_ids = ids[:get_length]
        entities = client.get_entities(collection, get_ids)
        assert len(entities) == get_length

@@ -0,0 +1,158 @@
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *

ADD_TIMEOUT = 60
uid = "test_insert"


class TestInsertBase:
    """
    ******************************************************************
    The following cases are used to test `insert` function
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, client):
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("CPU does not support index_type: ivf_sq8h")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    def test_insert_with_empty_entity(self, client, collection):
        '''
        target: test insert with an empty entity list
        method: set an empty list as the insert method params
        expected: insert failed
        '''
        assert not client.insert(collection, [])

    def test_insert_with_entity(self, client, collection):
        '''
        target: test insert with an entity
        method: insert one entity and flush
        expected: count correct
        '''
        assert client.insert(collection, default_entity)
        res_flush = client.flush([collection])
        count = client.count_collection(collection)
        assert count == 1

    def test_insert_with_entities(self, client, collection):
        '''
        target: test insert with entities
        method: insert entities and flush
        expected: count correct
        '''
        assert client.insert(collection, default_entities)
        res_flush = client.flush([collection])
        count = client.count_collection(collection)
        assert count == default_nb

    def test_insert_with_field_not_match(self, client, collection):
        '''
        target: insert entity whose fields do not match the collection schema
        method: pop a field (int64) from the entity
        expected: insert failed
        '''
        entity = copy.deepcopy(default_entity)
        entity[0].pop("int64")
        logging.getLogger().info(entity)
        assert not client.insert(collection, entity)

    def test_insert_with_tag_not_existed(self, client, collection):
        '''
        target: test insert with a partition tag
        method: insert an entity with a tag that was never created
        expected: insert failed
        '''
        assert not client.insert(collection, default_entity, tag=default_tag)

    def test_insert_with_tag_existed(self, client, collection):
        '''
        target: test insert with a partition tag
        method: create the partition first, then insert an entity with its tag
        expected: count correct
        '''
        client.create_partition(collection, default_tag)
        assert client.insert(collection, default_entity, tag=default_tag)
        res_flush = client.flush([collection])
        count = client.count_collection(collection)
        assert count == 1


class TestInsertID:
    """
    ******************************************************************
    The following cases are used to test `insert` function with customized ids
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, client):
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("CPU does not support index_type: ivf_sq8h")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    def test_insert_with_entity_id_not_matched(self, client, id_collection):
        '''
        target: test insert into an id_collection without providing ids
        method: insert an entity that carries no "__id" field
        expected: insert failed
        '''
        assert not client.insert(id_collection, default_entity)

    def test_insert_with_entity(self, client, id_collection):
        '''
        target: test insert an entity into an id_collection
        method: insert an entity with a customized id
        expected: insert success
        '''
        entity = copy.deepcopy(default_entity)
        entity[0].update({"__id": 1})
        logging.getLogger().info(entity)
        assert client.insert(id_collection, entity)
        res_flush = client.flush([id_collection])
        count = client.count_collection(id_collection)
        assert count == 1

@@ -0,0 +1,236 @@
import logging
import pytest
import requests
from utils import *
from constants import *

uid = "test_search"
epsilon = 0.001
field_name = default_float_vec_field_name
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
default_query, default_query_vectors = gen_query_vectors(field_name, default_entities, default_top_k, default_nq)


def init_data(client, collection, nb=default_nb, partition_tags=None, auto_id=True):
    """
    Generate entities and insert them into the collection
    """
    insert_entities = default_entities if nb == default_nb else gen_entities(nb)
    if partition_tags is None:
        if auto_id:
            ids = client.insert(collection, insert_entities)
        else:
            ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)])
    else:
        if auto_id:
            ids = client.insert(collection, insert_entities, partition_tag=partition_tags)
        else:
            ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
    client.flush([collection])
    return insert_entities, ids


def init_binary_data(client, collection, nb=default_nb, partition_tags=None, auto_id=True):
    """
    Generate binary entities and insert them into the collection
    """
    if nb == default_nb:
        insert_entities = default_binary_entities
        insert_raw_vectors = default_raw_binary_vectors
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
    if partition_tags is None:
        if auto_id:
            ids = client.insert(collection, insert_entities)
        else:
            ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)])
    else:
        if auto_id:
            ids = client.insert(collection, insert_entities, partition_tag=partition_tags)
        else:
            ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
    client.flush([collection])
    return insert_raw_vectors, insert_entities, ids


def check_id_result(results, ids):
    ids_res = []
    for result in results:
        ids_res.extend([int(item['id']) for item in result])
    for id in ids:
        if int(id) not in ids_res:
            return False
    return True
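

# check_id_result in a nutshell (a sanity sketch with literal data): every
# requested id must appear somewhere in the flattened result pages.
assert check_id_result([[{"id": "1"}, {"id": "2"}], [{"id": "3"}]], [1, 3])
assert not check_id_result([[{"id": "1"}]], [4])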


class TestSearchBase:
    """
    ******************************************************************
    The following cases are used to test `search` function
    ******************************************************************
    """

    # generate top-k / nq params
    @pytest.fixture(
        scope="function",
        params=[1, 10]
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=[1, 10]
    )
    def get_nq(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, client):
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not supported in CPU mode")
        return request.param

    def test_search_flat(self, client, collection, get_nq, get_top_k):
        """
        target: test basic search function with correct params, varying the top-k value
        method: search with the given vectors, check the result
        expected: the length of the result equals top_k
        """
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq)
        data = client.search(collection, query)
        res = data['result']
        assert data['num'] == nq
        assert len(res) == nq
        assert len(res[0]) == top_k
        assert float(res[0][0]['distance']) <= epsilon
        assert check_id_result(res, ids[:nq])

    # TODO
    def test_search_invalid_top_k(self, client, collection):
        """
        target: test search with an invalid top_k larger than max_top_k
        method: call search with invalid top_k
        expected: exception
        """
        top_k = max_top_k + 1
        nq = 1
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq)
        assert not client.search(collection, query)

    def test_search_fields(self, client, collection):
        """
        target: test search with extra fields
        method: call search with fields and check whether the result contains the field value
        expected: field value returned
        """
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
        data = client.search(collection, query, fields=[default_int_field_name])
        res = data['result']
        assert data['num'] == default_nq
        assert len(res) == default_nq
        assert len(res[0]) == default_top_k
        assert default_int_field_name in res[0][0]['entity'].keys()

    # TODO
    def test_search_invalid_n_probe(self, client, collection):
        """
        target: test basic search function with invalid nprobe
        method: call search function
        expected: not ok
        """
        entities, ids = init_data(client, collection)
        assert client.create_index(collection, default_float_vec_field_name, default_index)
        query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq,
                                                 search_params={"nprobe": 0})
        assert not client.search(collection, query)

    def test_search_not_existed_collection(self, client, collection):
        """
        target: test basic search with a collection that does not exist
        method: call search function
        expected: not ok
        """
        collection_name = gen_unique_str(uid)
        assert not client.search(collection_name, default_query)

    def test_search_empty_collection(self, client, collection):
        """
        target: test basic search function with an empty collection
        method: call search function
        expected: 0 entities returned
        """
        assert 0 == client.count_collection(collection)
        data = client.search(collection, default_query)
        res = data['result']
        assert data['num'] == 0
        assert len(res) == 0

    @pytest.fixture(
        scope="function",
        params=[
            1,
            "12-s",
            " ",
            "12 s",
            " siede ",
            "(mn)",
            "中文",
            "a".join("a" for i in range(256))
        ]
    )
    def get_invalid_collection_name(self, request):
        yield request.param

    def test_search_collection_name_invalid(self, client, collection, get_invalid_collection_name):
        """
        target: test search when the collection name is invalid
        method: call search with an invalid collection_name
        expected: status not ok
        """
        collection_name = get_invalid_collection_name
        assert not client.search(collection_name, default_query)

    # TODO
    def test_search_invalid_format_query(self, client, collection):
        """
        target: test search with an invalidly formatted query
        method: call search with an empty query vector list
        expected: status not ok
        """
        entities, ids = init_data(client, collection)
        must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[]], "params": {"nprobe": 10}}}}
        must_param["vector"][field_name]["metric_type"] = 'L2'
        query = {
            "bool": {
                "must": [must_param]
            }
        }
        assert not client.search(collection, query)

    # TODO
    def test_search_with_invalid_metric_type(self, client, collection):
        entities, ids = init_data(client, collection)
        query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l2")
        assert not client.search(collection, query)

    # TODO
    def test_search_binary_flat(self, client, binary_collection):
        raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
        query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,
                                                 default_nq)
        data = client.search(binary_collection, query)
        res = data['result']
        assert data['num'] == default_nq
        assert len(res) == default_nq
        assert len(res[0]) == default_top_k
        assert float(res[0][0]['distance']) <= epsilon
        assert check_id_result(res, ids[:default_nq])

@@ -0,0 +1,31 @@
import pdb
import random
from locust import User, task, between
from locust_task import MilvusTask
from utils import *
from constants import *

url = "http://192.168.1.29:19121/"
collection_name = "sift_128_euclidean"
headers = {'Content-Type': "application/json"}

default_query, default_query_vectors = gen_query_vectors(default_float_vec_field_name, default_entities, default_top_k, default_nq)


class HttpTest(User):
    wait_time = between(0, 0.1)
    client = MilvusTask(url)
    # client.clear_db()
    # client.create_collection(collection_name, default_fields)

    # @task
    # def insert(self):
    #     response = self.client.insert(collection_name, default_entities)

    @task
    def search(self):
        response = self.client.search(collection_name, default_query)
        # res = response['result']
        # assert response['num'] == default_nq
        # assert len(res) == default_nq
        # assert len(res[0]) == default_top_k
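

# Run sketch (flags from locust 0.x, which matches the request_success /
# request_failure events used in locust_task.py; newer locust versions
# renamed both the flags and the events):
#   locust -f <this file> --no-web -c 10 -r 10 -t 60s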
@@ -0,0 +1,36 @@
import logging
import time
import random
from locust import User, events
from client import MilvusClient

url = 'http://192.168.1.238:19121'


class MilvusTask(object):
    def __init__(self, url):
        self.request_type = "http"
        self.m = MilvusClient(url)
        # logging.getLogger().info(id(self.m))

    def __getattr__(self, name):
        func = getattr(self.m, name)

        def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
                total_time = int((time.time() - start_time) * 1000)
                if result:
                    events.request_success.fire(request_type=self.request_type, name=name,
                                                response_time=total_time, response_length=0)
                else:
                    # report a generic failure when the call returns a falsy result
                    events.request_failure.fire(request_type=self.request_type, name=name,
                                                response_time=total_time,
                                                exception=Exception("request returned a falsy result"),
                                                response_length=0)
            except Exception as e:
                total_time = int((time.time() - start_time) * 1000)
                events.request_failure.fire(request_type=self.request_type, name=name, response_time=total_time,
                                            exception=e, response_length=0)

        return wrapper
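

# A minimal sketch of the proxy in action (_DummyClient is a hypothetical
# stand-in for client.MilvusClient): any attribute access on MilvusTask is
# routed through __getattr__, timed, and reported via locust's events.
class _DummyClient:
    def search(self, collection_name, query):
        return {"num": 1, "result": [[{"id": "0", "distance": "0.0"}]]}


def _example_proxy_usage():
    task = MilvusTask.__new__(MilvusTask)  # skip __init__ to avoid a real HTTP client
    task.request_type = "http"
    task.m = _DummyClient()
    task.search("sift_128_euclidean", {})  # fires events.request_success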
@@ -0,0 +1,14 @@
[pytest]
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_date_format = %Y-%m-%d %H:%M:%S

log_cli = true
log_level = 20

timeout = 360

markers =
    level: test level
    serial

#level = 1
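
# Run sketch (the --level option assumes the pytest-level plugin pinned in
# requirements.txt; -n comes from pytest-xdist):
#   pytest . --level=1 -n 4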
@@ -0,0 +1,12 @@
numpy>=1.18.0
pylint==2.5.0
pytest==4.5.0
pytest-timeout==1.3.3
pytest-repeat==0.8.0
allure-pytest==2.7.0
pytest-print==0.1.2
pytest-level==0.1.1
pytest-xdist==1.23.2
scikit-learn>=0.19.1
kubernetes==10.0.1
pymilvus-test>=0.4.5

@@ -0,0 +1,25 @@
astroid==2.2.5
atomicwrites==1.3.0
attrs==19.1.0
importlib-metadata==0.15
isort==4.3.20
lazy-object-proxy==1.4.1
mccabe==0.6.1
more-itertools==7.0.0
numpy==1.16.3
pluggy==0.12.0
py==1.8.0
pylint==2.5.0
pytest==4.5.0
pytest-timeout==1.3.3
pytest-repeat==0.8.0
allure-pytest==2.7.0
pytest-print==0.1.2
pytest-level==0.1.1
six==1.12.0
thrift==0.11.0
typed-ast==1.3.5
wcwidth==0.1.7
wrapt==1.11.1
zipp==0.5.1
pymilvus>=0.2.0

@@ -0,0 +1,11 @@
numpy>=1.18.0
pylint==2.5.0
pytest==4.5.0
pytest-timeout==1.3.3
pytest-repeat==0.8.0
allure-pytest==2.7.0
pytest-print==0.1.2
pytest-level==0.1.1
pytest-xdist==1.23.2
scikit-learn>=0.19.1
kubernetes==10.0.1

@@ -0,0 +1,4 @@
#!/bin/bash


pytest . "$@"

@@ -0,0 +1,111 @@
import logging
import time
import pdb
import copy
import threading
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *

BUILD_TIMEOUT = 120
uid = "test_index"
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}


class TestIndexBase:
    """
    ******************************************************************
    The following cases are used to test `create_index` and `drop_index` functions
    ******************************************************************
    """

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, client):
        if str(client.system_cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("CPU does not support index_type: ivf_sq8h")
        return request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_filter_fields()
    )
    def get_filter_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_single_vector_fields()
    )
    def get_vector_field(self, request):
        yield request.param

    @pytest.mark.timeout(BUILD_TIMEOUT)
    def test_create_index(self, client, collection):
        '''
        target: test create index interface
        method: create collection, add entities in it, then create index
        expected: create index success, describe_index returns the created index
        '''
        ids = client.insert(collection, default_entities)
        res = client.create_index(collection, default_float_vec_field_name, default_index)
        # get index info
        logging.getLogger().info(res)
        res_info_index = client.describe_index(collection, default_float_vec_field_name)
        assert res_info_index == default_index

    def test_create_index_on_field_not_existed(self, client, collection):
        '''
        target: test create index interface
        method: create collection and add entities in it, create index on a field that does not exist
        expected: create failed
        '''
        tmp_field_name = gen_unique_str()
        ids = client.insert(collection, default_entities)
        assert not client.create_index(collection, tmp_field_name, default_index)

    @pytest.mark.level(2)
    def test_create_index_on_field(self, client, collection):
        '''
        target: test create index interface
        method: create collection and add entities in it, create index on a non-vector field
        expected: create failed
        '''
        tmp_field_name = "int64"
        ids = client.insert(collection, default_entities)
        assert not client.create_index(collection, tmp_field_name, default_index)

    def test_create_index_collection_not_existed(self, client):
        '''
        target: test create index interface when collection name not existed
        method: create index on a collection name that is not in the db
        expected: create index failed
        '''
        collection_name = gen_unique_str(uid)
        assert not client.create_index(collection_name, default_float_vec_field_name, default_index)

    def test_drop_index(self, client, collection):
        '''
        target: test drop index interface
        method: create index, then call drop index
        expected: drop success, no index info returned afterwards
        '''
        # ids = connect.insert(collection, entities)
        client.create_index(collection, default_float_vec_field_name, default_index)
        client.drop_index(collection, default_float_vec_field_name)
        res_info_index = client.describe_index(collection, default_float_vec_field_name)
        assert not res_info_index

    def test_drop_index_collection_not_existed(self, client):
        '''
        target: test drop index interface when collection name not existed
        method: drop index on a collection name that is not in the db
        expected: drop index failed
        '''
        collection_name = gen_unique_str(uid)
        assert not client.drop_index(collection_name, default_float_vec_field_name)

@@ -0,0 +1,201 @@
import time
import random
import pdb
import threading
import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *


class TestCreateBase:
    """
    ******************************************************************
    The following cases are used to test `create_partition` function
    ******************************************************************
    """

    def test_create_partition(self, client, collection):
        '''
        target: test create partition, check status returned
        method: call function: create_partition
        expected: status ok
        '''
        client.create_partition(collection, default_tag)

    def test_create_partition_repeat(self, client, collection):
        '''
        target: test create partition twice with the same tag, check status returned
        method: call function: create_partition, and again with the same tag
        expected: the second call fails
        '''
        client.create_partition(collection, default_tag)
        ret = client.create_partition(collection, default_tag)
        assert not ret

    def test_create_partition_collection_not_existed(self, client):
        '''
        target: test create partition when its owner collection is not in the db, check status returned
        method: call function: create_partition
        expected: status not ok
        '''
        collection_name = gen_unique_str()
        ret = client.create_partition(collection_name, default_tag)
        assert not ret

    def test_create_partition_tag_name_None(self, client, collection):
        '''
        target: test create partition with tag name set to None, check status returned
        method: call function: create_partition
        expected: status not ok
        '''
        tag_name = None
        ret = client.create_partition(collection, tag_name)
        assert not ret

    def test_create_different_partition_tags(self, client, collection):
        '''
        target: test create partition twice with different names
        method: call function: create_partition, and again with a new tag
        expected: status ok, both tags listed
        '''
        client.create_partition(collection, default_tag)
        tag_name = gen_unique_str()
        client.create_partition(collection, tag_name)
        ret = client.list_partitions(collection)
        tag_list = []
        for item in ret:
            tag_list.append(item['partition_tag'])
        assert default_tag in tag_list
        assert tag_name in tag_list
        assert "_default" in tag_list


class TestShowBase:
    """
    ******************************************************************
    The following cases are used to test `list_partitions` function
    ******************************************************************
    """

    def test_list_partitions(self, client, collection):
        '''
        target: test show partitions, check status and partitions returned
        method: create partition first, then call function: list_partitions
        expected: status ok, partition correct
        '''
        client.create_partition(collection, default_tag)
        ret = client.list_partitions(collection)
        tag_list = []
        for item in ret:
            tag_list.append(item['partition_tag'])
        assert default_tag in tag_list

    def test_list_partitions_no_partition(self, client, collection):
        '''
        target: test show partitions with collection name, check status and partitions returned
        method: call function: list_partitions
        expected: status ok, only the default partition listed
        '''
        res = client.list_partitions(collection)
        assert len(res) == 1


class TestDropBase:
    """
    ******************************************************************
    The following cases are used to test `drop_partition` function
    ******************************************************************
    """

    def test_drop_partition(self, client, collection):
        '''
        target: test drop partition, check status and whether the partition still exists
        method: create partition first, then call function: drop_partition
        expected: status ok, the partition is gone from the db
        '''
        client.create_partition(collection, default_tag)
        client.drop_partition(collection, default_tag)
        res = client.list_partitions(collection)
        tag_list = []
        for item in res:
            tag_list.append(item['partition_tag'])
        assert default_tag not in tag_list

    def test_drop_partition_tag_not_existed(self, client, collection):
        '''
        target: test drop partition when the tag does not exist
        method: create a partition first, then drop a different tag
        expected: status not ok
        '''
        client.create_partition(collection, default_tag)
        new_tag = "new_tag"
        ret = client.drop_partition(collection, new_tag)
        assert not ret

    def test_drop_partition_collection_not_existed(self, client, collection):
        '''
        target: test drop partition when the collection does not exist
        method: create a partition first, then drop it from a non-existent collection
        expected: status not ok
        '''
        client.create_partition(collection, default_tag)
        new_collection = gen_unique_str()
        ret = client.drop_partition(new_collection, default_tag)
        assert not ret


class TestNameInvalid(object):
    @pytest.fixture(
        scope="function",
        params=[
            "12-s",
            "(mn)",
            "中文",
            "a".join("a" for i in range(256))
        ]
    )
    def get_tag_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=[
            "12-s",
            "(mn)",
            "中文",
            "a".join("a" for i in range(256))
        ]
    )
    def get_collection_name(self, request):
        yield request.param

    def test_drop_partition_with_invalid_collection_name(self, client, collection, get_collection_name):
        '''
        target: test drop partition with an invalid collection name, check status returned
        method: call function: drop_partition
        expected: status not ok
        '''
        collection_name = get_collection_name
        client.create_partition(collection, default_tag)
        ret = client.drop_partition(collection_name, default_tag)
        assert not ret

    def test_drop_partition_with_invalid_tag_name(self, client, collection, get_tag_name):
        '''
        target: test drop partition with an invalid tag name, check status returned
        method: call function: drop_partition
        expected: status not ok
        '''
        tag_name = get_tag_name
        client.create_partition(collection, default_tag)
        ret = client.drop_partition(collection, tag_name)
        assert not ret

    def test_list_partitions_with_invalid_collection_name(self, client, collection, get_collection_name):
        '''
        target: test show partitions with an invalid collection name, check status returned
        method: call function: list_partitions
        expected: status not ok
        '''
        collection_name = get_collection_name
        client.create_partition(collection, default_tag)
        ret = client.list_partitions(collection_name)
        assert not ret

@@ -0,0 +1,47 @@
import logging
import pytest
import pdb
from utils import *

__version__ = '0.11.0'


class TestPing:
    def test_server_version(self, client):
        '''
        target: test get the server version
        method: call the server_version method after connected
        expected: version should be the milvus version
        '''
        res = client.system_cmd("version")
        assert res == __version__

    def test_server_status(self, client):
        '''
        target: test get the server status
        method: call the server_status method after connected
        expected: status returned should be ok
        '''
        msg = client.system_cmd("status")
        assert msg

    def test_server_cmd_with_params_version(self, client):
        '''
        target: test cmd: version
        method: cmd = "version"
        expected: return the version of the server
        '''
        cmd = "version"
        msg = client.system_cmd(cmd)
        logging.getLogger().info(msg)
        assert msg == __version__

    def test_server_cmd_with_params_others(self, client):
        '''
        target: test an unknown cmd
        method: cmd = "rm -rf test"
        expected: the server replies with the unknown-command message
        '''
        cmd = "rm -rf test"
        msg = client.system_cmd(cmd)
        assert msg == default_unknown_cmd

@@ -0,0 +1,894 @@
import os
import sys
import random
import pdb
import string
import struct
import logging
import time, datetime
import copy
import numpy as np
from sklearn import preprocessing
from milvus import Milvus, DataType

port = 19530
epsilon = 0.000001
namespace = "milvus"

default_flush_interval = 1
big_flush_interval = 1000
default_drop_interval = 3
default_dim = 128
default_nb = 1200
default_top_k = 10
default_nq = 1
max_top_k = 16384
max_partition_num = 256
default_segment_row_limit = 1000
default_server_segment_row_limit = 1024 * 512
default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"
default_int_field_name = "int64"
default_float_field_name = "float"
default_double_field_name = "double"

default_partition_name = "_default"
default_tag = "1970_01_01"
default_other_fields = ["INT64", "FLOAT"]

default_unknown_cmd = "Unknown command"

# TODO: RHNSW_SQ/PQ and NSG are disabled in 0.11.0
all_index_types = [
    "FLAT",
    "IVF_FLAT",
    "IVF_SQ8",
    "IVF_SQ8_HYBRID",
    "IVF_PQ",
    "HNSW",
    # "NSG",
    "ANNOY",
    # "RHNSW_PQ",
    # "RHNSW_SQ",
    "BIN_FLAT",
    "BIN_IVF_FLAT"
]

# build params aligned positionally with all_index_types above
default_index_params = [
    {"nlist": 128},
    {"nlist": 128},
    {"nlist": 128},
    {"nlist": 128},
    {"nlist": 128, "m": 16},
    {"M": 48, "efConstruction": 500},
    # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
    {"n_trees": 50},
    # {"M": 48, "efConstruction": 500, "PQM": 16},
    # {"M": 48, "efConstruction": 500},
    {"nlist": 128},
    {"nlist": 128}
]


def index_cpu_not_support():
    return ["IVF_SQ8_HYBRID"]


def binary_support():
    return ["BIN_FLAT", "BIN_IVF_FLAT"]


def delete_support():
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]


def ivf():
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]


def binary_metrics():
    return ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]


def structure_metrics():
    return ["SUBSTRUCTURE", "SUPERSTRUCTURE"]


def l2(x, y):
    return np.linalg.norm(np.array(x) - np.array(y))


def ip(x, y):
    return np.inner(np.array(x), np.array(y))


def jaccard(x, y):
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())


def hamming(x, y):
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return np.bitwise_xor(x, y).sum()


def tanimoto(x, y):
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return -np.log2(np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum()))


def substructure(x, y):
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(y)


def superstructure(x, y):
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x)
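

# Sanity sketch for the bitwise metrics above (pure-python literals):
# x and y share one set bit out of three set overall and differ in two
# positions, so jaccard = 1 - 1/3 and hamming = 2.
def _check_binary_metrics():
    x = [1, 1, 0, 0]
    y = [1, 0, 1, 0]
    assert hamming(x, y) == 2
    assert abs(jaccard(x, y) - (1 - 1 / 3)) < epsilon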


def get_default_field_name(data_type):
    if data_type == "VECTOR_FLOAT":
        field_name = default_float_vec_field_name
    elif data_type == "VECTOR_BINARY":
        field_name = default_binary_vec_field_name
    elif data_type == "INT64":
        field_name = default_int_field_name
    elif data_type == "FLOAT":
        field_name = default_float_field_name
    elif data_type == "DOUBLE":
        field_name = default_double_field_name
    else:
        raise ValueError("unknown data_type: %s" % data_type)
    return field_name


def reset_build_index_threshold(connect):
    connect.set_config("engine", "build_index_threshold", 1024)


def disable_flush(connect):
    connect.set_config("storage", "auto_flush_interval", big_flush_interval)


def enable_flush(connect):
    # reset auto_flush_interval=1
    connect.set_config("storage", "auto_flush_interval", default_flush_interval)
    config_value = connect.get_config("storage", "auto_flush_interval")
    assert config_value == str(default_flush_interval)


def gen_inaccuracy(num):
    return num / 255.0


def gen_vectors(num, dim, is_normal=True):
    vectors = [[random.random() for _ in range(dim)] for _ in range(num)]
    vectors = preprocessing.normalize(vectors, axis=1, norm='l2')
    return vectors.tolist()


# def gen_vectors(num, dim, seed=np.random.RandomState(1234), is_normal=False):
#     xb = seed.rand(num, dim).astype("float32")
#     xb = preprocessing.normalize(xb, axis=1, norm='l2')
#     return xb.tolist()


def gen_binary_vectors(num, dim):
    raw_vectors = []
    binary_vectors = []
    for i in range(num):
        raw_vector = [random.randint(0, 1) for i in range(dim)]
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors
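

# Round-trip sketch for the packing used above: default_dim random bits are
# packed into default_dim // 8 bytes with np.packbits and can be recovered
# with np.unpackbits.
def _check_binary_packing():
    raw, packed = gen_binary_vectors(1, default_dim)
    assert len(packed[0]) == default_dim // 8
    recovered = np.unpackbits(np.frombuffer(packed[0], dtype=np.uint8)).tolist()
    assert recovered == raw[0]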


def gen_binary_sub_vectors(vectors, length):
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for i in range(length):
        raw_vector = [0 for i in range(dim)]
        vector = vectors[i]
        for index, j in enumerate(vector):
            if j == 1:
                raw_vector[index] = 1
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors


def gen_binary_super_vectors(vectors, length):
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for i in range(length):
        # an all-ones vector is a superstructure of any vector of the same dim
        raw_vector = [1 for i in range(dim)]
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors
def gen_int_attr(row_num):
    return [random.randint(0, 255) for _ in range(row_num)]


def gen_float_attr(row_num):
    return [random.uniform(0, 255) for _ in range(row_num)]


def gen_unique_str(str_value=None):
    prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
    return "test_" + prefix if str_value is None else str_value + "_" + prefix


def gen_single_filter_fields():
    fields = []
    for data_type in DataType:
        if data_type in [DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]:
            fields.append({"name": data_type.name, "type": data_type})
    return fields


def gen_single_vector_fields():
    fields = []
    for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
        field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}}
        fields.append(field)
    return fields


def gen_default_fields(auto_id=True, binary=False):
    fields = [
        {"name": default_int_field_name, "type": "INT64"},
        {"name": default_float_field_name, "type": "FLOAT"}
    ]
    if binary is False:
        field = {"name": default_float_vec_field_name, "type": "VECTOR_FLOAT",
                 "params": {"dim": default_dim}}
    else:
        field = {"name": default_binary_vec_field_name, "type": "VECTOR_BINARY",
                 "params": {"dim": default_dim}}
    fields.append(field)
    default_fields = {
        "fields": fields,
        "segment_row_limit": default_segment_row_limit,
        "auto_id": auto_id
    }
    return default_fields
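

# For reference, the mapping produced by gen_default_fields() with the
# defaults above:
# {
#     "fields": [
#         {"name": "int64", "type": "INT64"},
#         {"name": "float", "type": "FLOAT"},
#         {"name": "float_vector", "type": "VECTOR_FLOAT", "params": {"dim": default_dim}},
#     ],
#     "segment_row_limit": default_segment_row_limit,
#     "auto_id": True,
# }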


def gen_entities(nb, is_normal=False):
    entities = []
    vectors = gen_vectors(nb, default_dim, is_normal)
    for i in range(nb):
        entity = {
            "int64": i,
            "float": float(i),
            default_float_vec_field_name: vectors[i]
        }
        entities.append(entity)
    return entities


def gen_binary_entities(nb):
    raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
    entities = []
    for i in range(nb):
        entity = {
            default_int_field_name: i,
            default_float_field_name: float(i),
            default_binary_vec_field_name: vectors[i]
        }
        entities.append(entity)
    return raw_vectors, entities


def gen_entities_by_fields(fields, nb, dim):
    entities = []
    for field in fields:
        if field["type"] in [DataType.INT32, DataType.INT64]:
            field_value = [1 for i in range(nb)]
        elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
            field_value = [3.0 for i in range(nb)]
        elif field["type"] == DataType.BINARY_VECTOR:
            field_value = gen_binary_vectors(nb, dim)[1]
        elif field["type"] == DataType.FLOAT_VECTOR:
            field_value = gen_vectors(nb, dim)
        field.update({"values": field_value})
        entities.append(field)
    return entities


def assert_equal_entity(a, b):
    pass


def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
                      metric_type="L2", replace_vecs=None):
    # pull query vectors from the field being searched, so this also works
    # for binary collections
    if rand_vector is True:
        dimension = len(entities[0][field_name])
        query_vectors = gen_vectors(nq, dimension)
    else:
        query_vectors = list(map(lambda x: x[field_name], entities[:nq]))
    if replace_vecs:
        query_vectors = replace_vecs
    must_param = {"vector": {field_name: {"topk": top_k, "values": query_vectors, "params": search_params}}}
    must_param["vector"][field_name]["metric_type"] = metric_type
    query = {
        "bool": {
            "must": [must_param]
        }
    }
    return query, query_vectors
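

# Shape sketch: for nq=1, top_k=10 and the default search params, the query
# assembled above looks like this (the 128-d vector elided):
# {
#     "bool": {
#         "must": [{
#             "vector": {
#                 "float_vector": {
#                     "topk": 10,
#                     "values": [[...]],
#                     "params": {"nprobe": 10},
#                     "metric_type": "L2"
#                 }
#             }
#         }]
#     }
# }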


def update_query_expr(src_query, keep_old=True, expr=None):
    tmp_query = copy.deepcopy(src_query)
    if expr is not None:
        tmp_query["bool"].update(expr)
    if keep_old is not True:
        tmp_query["bool"].pop("must")
    return tmp_query


def gen_default_vector_expr(default_query):
    return default_query["bool"]["must"][0]


def gen_default_term_expr(keyword="term", field="int64", values=None):
    if values is None:
        values = [i for i in range(default_nb // 2)]
    expr = {keyword: {field: {"values": values}}}
    return expr


def update_term_expr(src_term, terms):
    tmp_term = copy.deepcopy(src_term)
    for term in terms:
        tmp_term["term"].update(term)
    return tmp_term


def gen_default_range_expr(keyword="range", field="int64", ranges=None):
    if ranges is None:
        ranges = {"GT": 1, "LT": default_nb // 2}
    expr = {keyword: {field: ranges}}
    return expr


def update_range_expr(src_range, ranges):
    tmp_range = copy.deepcopy(src_range)
    for r in ranges:
        tmp_range["range"].update(r)
    return tmp_range


def gen_invalid_range():
    invalid_ranges = [
        {"range": 1},
        {"range": {}},
        {"range": []},
        {"range": {"range": {"int64": {"GT": 0, "LT": default_nb // 2}}}}
    ]
    return invalid_ranges


def gen_valid_ranges():
    ranges = [
        {"GT": 0, "LT": default_nb // 2},
        {"GT": default_nb // 2, "LT": default_nb * 2},
        {"GT": 0},
        {"LT": default_nb},
        {"GT": -1, "LT": default_top_k},
    ]
    return ranges


def gen_invalid_term():
    terms = [
        {"term": 1},
        {"term": []},
        {"term": {}},
        {"term": {"term": {"int64": {"values": [i for i in range(default_nb // 2)]}}}}
    ]
    return terms
def add_field_default(default_fields, type=DataType.INT64, field_name=None):
    tmp_fields = copy.deepcopy(default_fields)
    if field_name is None:
        field_name = gen_unique_str()
    field = {
        "name": field_name,
        "type": type
    }
    tmp_fields["fields"].append(field)
    return tmp_fields


def add_field(entities, field_name=None):
    nb = len(entities[0]["values"])
    tmp_entities = copy.deepcopy(entities)
    if field_name is None:
        field_name = gen_unique_str()
    field = {
        "name": field_name,
        "type": DataType.INT64,
        "values": [i for i in range(nb)]
    }
    tmp_entities.append(field)
    return tmp_entities


def add_vector_field(entities, is_normal=False):
    nb = len(entities[0]["values"])
    vectors = gen_vectors(nb, default_dim, is_normal)
    field = {
        "name": gen_unique_str(),
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }
    entities.append(field)
    return entities


# def update_fields_metric_type(fields, metric_type):
#     tmp_fields = copy.deepcopy(fields)
#     if metric_type in ["L2", "IP"]:
#         tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR
#     else:
#         tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR
#     tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type
#     return tmp_fields


def remove_field(entities):
    del entities[0]
    return entities


def remove_vector_field(entities):
    del entities[-1]
    return entities


def update_field_name(entities, old_name, new_name):
    tmp_entities = copy.deepcopy(entities)
    for item in tmp_entities:
        if item["name"] == old_name:
            item["name"] = new_name
    return tmp_entities


def update_field_type(entities, old_name, new_type):
    tmp_entities = copy.deepcopy(entities)
    for item in tmp_entities:
        if item["name"] == old_name:
            item["type"] = new_type
    return tmp_entities


def update_field_value(entities, old_type, new_value):
    tmp_entities = copy.deepcopy(entities)
    for item in tmp_entities:
        if item["type"] == old_type:
            for index, value in enumerate(item["values"]):
                item["values"][index] = new_value
    return tmp_entities


# builds a standalone named vector field and returns its generated name
# (renamed: a second `add_vector_field` definition shadowed the one above)
def gen_vector_field(nb, dimension=default_dim):
    field_name = gen_unique_str()
    field = {
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(nb, dimension)
    }
    return field_name


def gen_segment_row_limits():
    sizes = [
        1024,
        4096
    ]
    return sizes
def gen_invalid_ips():
    ips = [
        # "255.0.0.0",
        # "255.255.0.0",
        # "255.255.255.0",
        # "255.255.255.255",
        "127.0.0",
        # "123.0.0.2",
        "12-s",
        " ",
        "12 s",
        "BB。A",
        " siede ",
        "(mn)",
        "中文",
        "a".join("a" for _ in range(256))
    ]
    return ips


def gen_invalid_uris():
    ip = None
    uris = [
        " ",
        "中文",
        # invalid protocol
        # "tc://%s:%s" % (ip, port),
        # "tcp%s:%s" % (ip, port),

        # invalid port
        # "tcp://%s:100000" % ip,
        # "tcp://%s: " % ip,
        # "tcp://%s:19540" % ip,
        # "tcp://%s:-1" % ip,
        # "tcp://%s:string" % ip,

        # invalid ip
        "tcp:// :19530",
        # "tcp://123.0.0.1:%s" % port,
        "tcp://127.0.0:19530",
        # "tcp://255.0.0.0:%s" % port,
        # "tcp://255.255.0.0:%s" % port,
        # "tcp://255.255.255.0:%s" % port,
        # "tcp://255.255.255.255:%s" % port,
        "tcp://\n:19530",
    ]
    return uris


def gen_invalid_strs():
    strings = [
        1,
        [1],
        None,
        "12-s",
        " ",
        # "",
        # None,
        "12 s",
        " siede ",
        "(mn)",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return strings
def gen_invalid_field_types():
    field_types = [
        # 1,
        "=c",
        # 0,
        None,
        "",
        "a".join("a" for i in range(256))
    ]
    return field_types


def gen_invalid_metric_types():
    metric_types = [
        1,
        "=c",
        0,
        None,
        "",
        "a".join("a" for i in range(256))
    ]
    return metric_types


# TODO:
def gen_invalid_ints():
    int_values = [
        # 1.0,
        None,
        [1, 2, 3],
        " ",
        "",
        -1,
        "String",
        "=c",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return int_values


def gen_invalid_params():
    params = [
        9999999999,
        -1,
        # None,
        [1, 2, 3],
        " ",
        "",
        "String",
        "中文"
    ]
    return params


def gen_invalid_vectors():
    invalid_vectors = [
        "1*2",
        [],
        [1],
        [1, 2],
        [" "],
        ['a'],
        [None],
        None,
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        " siede ",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return invalid_vectors


def gen_invalid_search_params():
    invalid_search_key = 100
    search_params = []
    for index_type in all_index_types:
        if index_type == "FLAT":
            continue
        search_params.append({"index_type": index_type, "search_params": {"invalid_key": invalid_search_key}})
        if index_type in delete_support():
            for nprobe in gen_invalid_params():
                ivf_search_params = {"index_type": index_type, "search_params": {"nprobe": nprobe}}
                search_params.append(ivf_search_params)
        elif index_type in ["HNSW", "RHNSW_PQ", "RHNSW_SQ"]:
            for ef in gen_invalid_params():
                hnsw_search_param = {"index_type": index_type, "search_params": {"ef": ef}}
                search_params.append(hnsw_search_param)
        elif index_type == "NSG":
            for search_length in gen_invalid_params():
                nsg_search_param = {"index_type": index_type, "search_params": {"search_length": search_length}}
                search_params.append(nsg_search_param)
            search_params.append({"index_type": index_type, "search_params": {"invalid_key": 100}})
        elif index_type == "ANNOY":
            for search_k in gen_invalid_params():
                if isinstance(search_k, int):
                    continue
                annoy_search_param = {"index_type": index_type, "search_params": {"search_k": search_k}}
                search_params.append(annoy_search_param)
    return search_params


# backward-compatible alias for the original misspelled name
gen_invaild_search_params = gen_invalid_search_params
def gen_invalid_index():
|
||||
index_params = []
|
||||
for index_type in gen_invalid_strs():
|
||||
index_param = {"index_type": index_type, "params": {"nlist": 1024}}
|
||||
index_params.append(index_param)
|
||||
for nlist in gen_invalid_params():
|
||||
index_param = {"index_type": "IVF_FLAT", "params": {"nlist": nlist}}
|
||||
index_params.append(index_param)
|
||||
for M in gen_invalid_params():
|
||||
index_param = {"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}
|
||||
index_param = {"index_type": "RHNSW_PQ", "params": {"M": M, "efConstruction": 100}}
|
||||
index_param = {"index_type": "RHNSW_SQ", "params": {"M": M, "efConstruction": 100}}
|
||||
index_params.append(index_param)
|
||||
for efConstruction in gen_invalid_params():
|
||||
index_param = {"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}
|
||||
index_param = {"index_type": "RHNSW_PQ", "params": {"M": 16, "efConstruction": efConstruction}}
|
||||
index_param = {"index_type": "RHNSW_SQ", "params": {"M": 16, "efConstruction": efConstruction}}
|
||||
index_params.append(index_param)
|
||||
for search_length in gen_invalid_params():
|
||||
index_param = {"index_type": "NSG",
|
||||
"params": {"search_length": search_length, "out_degree": 40, "candidate_pool_size": 50,
|
||||
"knng": 100}}
|
||||
index_params.append(index_param)
|
||||
for out_degree in gen_invalid_params():
|
||||
index_param = {"index_type": "NSG",
|
||||
"params": {"search_length": 100, "out_degree": out_degree, "candidate_pool_size": 50,
|
||||
"knng": 100}}
|
||||
index_params.append(index_param)
|
||||
for candidate_pool_size in gen_invalid_params():
|
||||
index_param = {"index_type": "NSG", "params": {"search_length": 100, "out_degree": 40,
|
||||
"candidate_pool_size": candidate_pool_size,
|
||||
"knng": 100}}
|
||||
index_params.append(index_param)
|
||||
index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
|
||||
index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
|
||||
index_params.append({"index_type": "RHNSW_PQ", "params": {"invalid_key": 16, "efConstruction": 100}})
|
||||
index_params.append({"index_type": "RHNSW_SQ", "params": {"invalid_key": 16, "efConstruction": 100}})
|
||||
index_params.append({"index_type": "NSG",
|
||||
"params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
|
||||
"knng": 100}})
|
||||
for invalid_n_trees in gen_invalid_params():
|
||||
index_params.append({"index_type": "ANNOY", "params": {"n_trees": invalid_n_trees}})
|
||||
|
||||
return index_params
|
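A sketch of how these malformed definitions are typically exercised (assuming the `connect` and `collection` fixtures from conftest, and the `float_vector` field name used throughout these tests):

```python
import pytest

from utils import gen_invalid_index  # assuming utils.py is on the import path


@pytest.fixture(scope="function", params=gen_invalid_index())
def get_invalid_index(request):
    yield request.param


def test_create_index_invalid(connect, collection, get_invalid_index):
    # every malformed index definition should be rejected by the server
    with pytest.raises(Exception):
        connect.create_index(collection, "float_vector", get_invalid_index)
```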
||||
|
||||
|
||||
def gen_index():
|
||||
nlists = [1, 1024, 16384]
|
||||
pq_ms = [128, 64, 32, 16, 8, 4]
|
||||
Ms = [5, 24, 48]
|
||||
efConstructions = [100, 300, 500]
|
||||
search_lengths = [10, 100, 300]
|
||||
out_degrees = [5, 40, 300]
|
||||
candidate_pool_sizes = [50, 100, 300]
|
||||
knngs = [5, 100, 300]
|
||||
|
||||
index_params = []
|
||||
for index_type in all_index_types:
|
||||
if index_type in ["FLAT", "BIN_FLAT", "BIN_IVF_FLAT"]:
|
||||
index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}})
|
||||
elif index_type in ["IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID"]:
|
||||
ivf_params = [{"index_type": index_type, "index_param": {"nlist": nlist}} \
|
||||
for nlist in nlists]
|
||||
index_params.extend(ivf_params)
|
||||
elif index_type == "IVF_PQ":
|
||||
IVFPQ_params = [{"index_type": index_type, "index_param": {"nlist": nlist, "m": m}} \
|
||||
for nlist in nlists \
|
||||
for m in pq_ms]
|
||||
index_params.extend(IVFPQ_params)
|
||||
elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
|
||||
hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}} \
|
||||
for M in Ms \
|
||||
for efConstruction in efConstructions]
|
||||
index_params.extend(hnsw_params)
|
||||
elif index_type == "NSG":
|
||||
nsg_params = [{"index_type": index_type,
|
||||
"index_param": {"search_length": search_length, "out_degree": out_degree,
|
||||
"candidate_pool_size": candidate_pool_size, "knng": knng}} \
|
||||
for search_length in search_lengths \
|
||||
for out_degree in out_degrees \
|
||||
for candidate_pool_size in candidate_pool_sizes \
|
||||
for knng in knngs]
|
||||
index_params.extend(nsg_params)
|
||||
|
||||
return index_params
|
||||
|
||||
|
||||
def gen_simple_index():
|
||||
index_params = []
|
||||
for i in range(len(all_index_types)):
|
||||
if all_index_types[i] in binary_support():
|
||||
continue
|
||||
dic = {"index_type": all_index_types[i], "metric_type": "L2"}
|
||||
dic.update({"params": default_index_params[i]})
|
||||
index_params.append(dic)
|
||||
return index_params
|
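Each element pairs an index type with the matching entry from `default_index_params`; the element shape looks like this, with an illustrative value (the actual params come from `default_index_params`):

```python
# illustrative shape only; "params" is taken from default_index_params
{"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}}
```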
||||
|
||||
|
||||
def gen_binary_index():
|
||||
index_params = []
|
||||
for i in range(len(all_index_types)):
|
||||
if all_index_types[i] in binary_support():
|
||||
dic = {"index_type": all_index_types[i]}
|
||||
dic.update({"params": default_index_params[i]})
|
||||
index_params.append(dic)
|
||||
return index_params
|
||||
|
||||
|
||||
def get_search_param(index_type, metric_type="L2"):
|
||||
search_params = {"metric_type": metric_type}
|
||||
if index_type in ivf() or index_type in binary_support():
|
||||
search_params.update({"nprobe": 64})
|
||||
elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
|
||||
search_params.update({"ef": 64})
|
||||
elif index_type == "NSG":
|
||||
search_params.update({"search_length": 100})
|
||||
elif index_type == "ANNOY":
|
||||
search_params.update({"search_k": 1000})
|
||||
else:
|
||||
logging.getLogger().error("Invalid index_type.")
|
||||
raise Exception("Invalid index_type.")
|
||||
return search_params
|
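For reference, the shapes this returns, derived directly from the branches above (assuming `ivf()` includes `IVF_FLAT`, as elsewhere in this file):

```python
assert get_search_param("IVF_FLAT") == {"metric_type": "L2", "nprobe": 64}
assert get_search_param("HNSW", metric_type="IP") == {"metric_type": "IP", "ef": 64}
assert get_search_param("ANNOY") == {"metric_type": "L2", "search_k": 1000}
```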
||||
|
||||
|
||||
def assert_equal_vector(v1, v2):
|
||||
assert len(v1) == len(v2)
|
||||
for i in range(len(v1)):
|
||||
assert abs(v1[i] - v2[i]) < epsilon
|
||||
|
||||
|
||||
def restart_server(helm_release_name):
|
||||
res = True
|
||||
timeout = 120
|
||||
from kubernetes import client, config
|
||||
client.rest.logger.setLevel(logging.WARNING)
|
||||
|
||||
# service_name = "%s.%s.svc.cluster.local" % (helm_release_name, namespace)
|
||||
config.load_kube_config()
|
||||
v1 = client.CoreV1Api()
|
||||
pod_name = None
|
||||
# config_map_names = v1.list_namespaced_config_map(namespace, pretty='true')
|
||||
# body = {"replicas": 0}
|
||||
pods = v1.list_namespaced_pod(namespace)
|
||||
for i in pods.items:
|
||||
if i.metadata.name.find(helm_release_name) != -1 and i.metadata.name.find("mysql") == -1:
|
||||
pod_name = i.metadata.name
|
||||
break
|
||||
# v1.patch_namespaced_config_map(config_map_name, namespace, body, pretty='true')
|
||||
# status_res = v1.read_namespaced_service_status(helm_release_name, namespace, pretty='true')
|
||||
logging.getLogger().debug("Pod name: %s" % pod_name)
|
||||
if pod_name is not None:
|
||||
try:
|
||||
v1.delete_namespaced_pod(pod_name, namespace)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
|
||||
res = False
|
||||
return res
|
||||
logging.error("Sleep 10s after pod deleted")
|
||||
time.sleep(10)
|
||||
# check if restart successfully
|
||||
pods = v1.list_namespaced_pod(namespace)
|
||||
for i in pods.items:
|
||||
pod_name_tmp = i.metadata.name
|
||||
logging.debug(pod_name_tmp)
|
||||
if pod_name_tmp == pod_name:
|
||||
continue
|
||||
elif pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1:
|
||||
continue
|
||||
else:
|
||||
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
|
||||
logging.debug(status_res.status.phase)
|
||||
start_time = time.time()
|
||||
ready_break = False
|
||||
while time.time() - start_time <= timeout:
|
||||
logging.debug(time.time())
|
||||
status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
|
||||
if status_res.status.phase == "Running":
|
||||
logging.error("Already running")
|
||||
ready_break = True
|
||||
time.sleep(10)
|
||||
break
|
||||
else:
|
||||
time.sleep(1)
|
||||
if time.time() - start_time > timeout:
|
||||
logging.error("Restart pod: %s timeout" % pod_name_tmp)
|
||||
res = False
|
||||
return res
|
||||
if ready_break:
|
||||
break
|
||||
else:
|
||||
raise Exception("Pod: %s not found" % pod_name)
|
||||
follow = True
|
||||
pretty = True
|
||||
previous = True # bool | Return previous terminated container logs. Defaults to false. (optional)
|
||||
since_seconds = 56 # int | A relative time in seconds before the current time from which to show logs. If this value precedes the time a pod was started, only logs since the pod start will be returned. If this value is in the future, no logs will be returned. Only one of sinceSeconds or sinceTime may be specified. (optional)
|
||||
timestamps = True # bool | If true, add an RFC3339 or RFC3339Nano timestamp at the beginning of every line of log output. Defaults to false. (optional)
|
||||
container = "milvus"
|
||||
# start_time = time.time()
|
||||
# while time.time() - start_time <= timeout:
|
||||
# try:
|
||||
# api_response = v1.read_namespaced_pod_log(pod_name_tmp, namespace, container=container, follow=follow,
|
||||
# pretty=pretty, previous=previous, since_seconds=since_seconds,
|
||||
# timestamps=timestamps)
|
||||
# logging.error(api_response)
|
||||
# return res
|
||||
# except Exception as e:
|
||||
# logging.error("Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
|
||||
# # waiting for server start
|
||||
# time.sleep(5)
|
||||
# # res = False
|
||||
# # return res
|
||||
# if time.time() - start_time > timeout:
|
||||
# logging.error("Restart pod: %s timeout" % pod_name_tmp)
|
||||
# res = False
|
||||
return res
|
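Typical use from a test that exercises server recovery, a minimal sketch (the release name is illustrative; `restart_server` also assumes a loadable kube config and the module-level `namespace`):

```python
# hypothetical chaos-style check: kill the pod and require recovery within the timeout
if not restart_server("milvus-test"):
    raise RuntimeError("Milvus pod did not return to Running state in time")
```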
|
@ -1,32 +1,62 @@
|
|||
# Requirements
|
||||
## Requirements
|
||||
* python 3.6.8+
|
||||
* pip install -r requirements.txt
|
||||
|
||||
# How to use this Test Project
|
||||
## How to Build Test Env
|
||||
```shell
|
||||
pytest . --level=1
|
||||
sudo docker pull registry.zilliz.com/milvus/milvus-test:v0.2
|
||||
sudo docker run -it -v /home/zilliz:/home/zilliz -d registry.zilliz.com/milvus/milvus-test:v0.2
|
||||
```
|
||||
or test connect function only
|
||||
|
||||
## How to Create Test Env docker in k8s
|
||||
```shell
|
||||
# 1. start milvus k8s pod
|
||||
cd milvus-helm/charts/milvus
|
||||
helm install --wait --timeout 300s \
|
||||
--set image.repository=registry.zilliz.com/milvus/engine \
|
||||
--set persistence.enabled=true \
|
||||
--set image.tag=PR-3818-gpu-centos7-release \
|
||||
--set image.pullPolicy=Always \
|
||||
--set service.type=LoadBalancer \
|
||||
-f ci/db_backend/mysql_gpu_values.yaml \
|
||||
-f ci/filebeat/values.yaml \
|
||||
-f test.yaml \
|
||||
--namespace milvus \
|
||||
milvus-ci-pr-3818-1-single-centos7-gpu .
|
||||
|
||||
# 2. remove milvus k8s pod
|
||||
helm uninstall -n milvus milvus-test
|
||||
|
||||
# 3. check k8s pod status
|
||||
kubectl get svc -n milvus -w milvus-test
|
||||
|
||||
# 4. login to pod
|
||||
kubectl get pods --namespace milvus
|
||||
kubectl exec -it milvus-test-writable-6cc49cfcd4-rbrns -n milvus bash
|
||||
```
|
||||
|
||||
## How to Run Test cases
|
||||
```shell
|
||||
# Test level-1 cases
|
||||
pytest . --level=1 --ip=127.0.0.1 --port=19530
|
||||
|
||||
# Test level-1 cases in 'test_connect.py' only
|
||||
pytest test_connect.py --level=1
|
||||
```
|
||||
|
||||
collect cases
|
||||
```shell
|
||||
pytest --day-run -qq
|
||||
```
|
||||
collect cases with docstring
|
||||
## How to list test cases
|
||||
```shell
|
||||
# List all cases
|
||||
pytest --dry-run -qq
|
||||
|
||||
# Collect all cases with docstring
|
||||
pytest --collect-only -qq
|
||||
```
|
||||
|
||||
with allure test report
|
||||
|
||||
```shell
|
||||
# Create test report with allure
|
||||
pytest --alluredir=test_out . -q -v
|
||||
allure serve test_out
|
||||
```
|
||||
# Contribution getting started
|
||||
|
||||
## Contribution getting started
|
||||
* Follow PEP-8 for naming and black for formatting.
|
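A minimal pre-commit formatting pass under that convention might look like this (assuming `black` is installed from PyPI):

```shell
pip install black
black .            # reformat the test sources in place
black --check .    # or only verify formatting, without rewriting files
```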
||||
|
||||
|
|
|
@ -9,20 +9,10 @@ import sklearn.preprocessing
|
|||
|
||||
import pytest
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
nb = 6000
|
||||
dim = 128
|
||||
tag = "tag"
|
||||
collection_id = "count_collection"
|
||||
add_interval_time = 3
|
||||
segment_row_count = 5000
|
||||
default_fields = gen_default_fields()
|
||||
default_binary_fields = gen_binary_default_fields()
|
||||
entities = gen_entities(nb)
|
||||
raw_vectors, binary_entities = gen_binary_entities(nb)
|
||||
field_name = "fload_vector"
|
||||
index_name = "index_name"
|
||||
|
||||
uid = "collection_count"
|
||||
tag = "collection_count_tag"
|
||||
|
||||
class TestCollectionCount:
|
||||
"""
|
||||
|
@ -32,8 +22,8 @@ class TestCollectionCount:
|
|||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
4000,
|
||||
6001
|
||||
1000,
|
||||
2001
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
|
@ -155,7 +145,7 @@ class TestCollectionCount:
|
|||
entities = gen_entities(insert_count)
|
||||
res = connect.insert(collection, entities)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
res = connect.count_entities(collection)
|
||||
assert res == insert_count
|
||||
|
||||
|
@ -187,8 +177,8 @@ class TestCollectionCountIP:
|
|||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
4000,
|
||||
6001
|
||||
1000,
|
||||
2001
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
|
@ -230,8 +220,8 @@ class TestCollectionCountBinary:
|
|||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
4000,
|
||||
6001
|
||||
1000,
|
||||
2001
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
|
@ -423,8 +413,8 @@ class TestCollectionMultiCollections:
|
|||
scope="function",
|
||||
params=[
|
||||
1,
|
||||
4000,
|
||||
6001
|
||||
1000,
|
||||
2001
|
||||
],
|
||||
)
|
||||
def insert_count(self, request):
|
||||
|
@ -485,7 +475,7 @@ class TestCollectionMultiCollections:
|
|||
collection_list = []
|
||||
collection_num = 20
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
res = connect.insert(collection_name, entities)
|
||||
|
@ -507,7 +497,7 @@ class TestCollectionMultiCollections:
|
|||
collection_list = []
|
||||
collection_num = 20
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_binary_fields)
|
||||
res = connect.insert(collection_name, entities)
|
||||
|
@ -527,16 +517,16 @@ class TestCollectionMultiCollections:
|
|||
collection_list = []
|
||||
collection_num = 20
|
||||
for i in range(0, int(collection_num / 2)):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
res = connect.insert(collection_name, entities)
|
||||
res = connect.insert(collection_name, default_entities)
|
||||
for i in range(int(collection_num / 2), collection_num):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_binary_fields)
|
||||
res = connect.insert(collection_name, binary_entities)
|
||||
res = connect.insert(collection_name, default_binary_entities)
|
||||
connect.flush(collection_list)
|
||||
for i in range(collection_num):
|
||||
res = connect.count_entities(collection_list[i])
|
||||
assert res == nb
|
||||
assert res == default_nb
|
||||
|
|
|
@ -6,17 +6,10 @@ from time import sleep
|
|||
from multiprocessing import Process
|
||||
from utils import *
|
||||
|
||||
dim = 128
|
||||
default_segment_row_count = 100000
|
||||
drop_collection_interval_time = 3
|
||||
segment_row_count = 5000
|
||||
collection_id = "logic"
|
||||
vectors = gen_vectors(100, dim)
|
||||
default_fields = gen_default_fields()
|
||||
|
||||
uid = "collection_logic"
|
||||
|
||||
def create_collection(connect, **params):
|
||||
connect.create_collection(params["collection_name"], default_fields)
|
||||
connect.create_collection(params["collection_name"], const.default_fields)
|
||||
|
||||
def search_collection(connect, **params):
|
||||
status, result = connect.search(
|
||||
|
@ -129,7 +122,7 @@ class TestCollectionLogic(object):
|
|||
connect.drop_collection(name)
|
||||
|
||||
def gen_params(self):
|
||||
collection_name = gen_unique_str("collection_id")
|
||||
collection_name = gen_unique_str(uid)
|
||||
top_k = 1
|
||||
vectors = gen_vectors(2, dim)
|
||||
param = {'collection_name': collection_name,
|
||||
|
|
|
@ -3,25 +3,12 @@ import pdb
|
|||
import threading
|
||||
import logging
|
||||
from multiprocessing import Pool, Process
|
||||
|
||||
import pytest
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
dim = 128
|
||||
segment_row_count = 5000
|
||||
nprobe = 1
|
||||
top_k = 1
|
||||
epsilon = 0.0001
|
||||
tag = "1970_01_01"
|
||||
nb = 6000
|
||||
nlist = 1024
|
||||
collection_id = "collection_stats"
|
||||
field_name = "float_vector"
|
||||
entity = gen_entities(1)
|
||||
raw_vector, binary_entity = gen_binary_entities(1)
|
||||
entities = gen_entities(nb)
|
||||
raw_vectors, binary_entities = gen_binary_entities(nb)
|
||||
default_fields = gen_default_fields()
|
||||
|
||||
uid = "collection_stats"
|
||||
|
||||
class TestStatsBase:
|
||||
"""
|
||||
|
@ -65,10 +52,11 @@ class TestStatsBase:
|
|||
method: call collection_stats with a random collection_name, which is not in db
|
||||
expected: status not ok
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
stats = connect.get_collection_stats(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_get_collection_stats_name_invalid(self, connect, get_collection_name):
|
||||
'''
|
||||
target: get collection stats where collection name is invalid
|
||||
|
@ -88,7 +76,7 @@ class TestStatsBase:
|
|||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == 0
|
||||
assert len(stats["partitions"]) == 1
|
||||
assert stats["partitions"][0]["tag"] == "_default"
|
||||
assert stats["partitions"][0]["tag"] == default_partition_name
|
||||
assert stats["partitions"][0]["row_count"] == 0
|
||||
|
||||
def test_get_collection_stats_batch(self, connect, collection):
|
||||
|
@ -97,13 +85,13 @@ class TestStatsBase:
|
|||
method: add entities, check count in collection info
|
||||
expected: count as expected
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats["row_count"] == default_nb
|
||||
assert len(stats["partitions"]) == 1
|
||||
assert stats["partitions"][0]["tag"] == "_default"
|
||||
assert stats["partitions"][0]["row_count"] == nb
|
||||
assert stats["partitions"][0]["tag"] == default_partition_name
|
||||
assert stats["partitions"][0]["row_count"] == default_nb
|
||||
|
||||
def test_get_collection_stats_single(self, connect, collection):
|
||||
'''
|
||||
|
@ -113,12 +101,12 @@ class TestStatsBase:
|
|||
'''
|
||||
nb = 10
|
||||
for i in range(nb):
|
||||
ids = connect.insert(collection, entity)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert len(stats["partitions"]) == 1
|
||||
assert stats["partitions"][0]["tag"] == "_default"
|
||||
assert stats["partitions"][0]["tag"] == default_partition_name
|
||||
assert stats["partitions"][0]["row_count"] == nb
|
||||
|
||||
def test_get_collection_stats_after_delete(self, connect, collection):
|
||||
|
@ -127,14 +115,14 @@ class TestStatsBase:
|
|||
method: add and delete entities, check count in collection info
|
||||
expected: status ok, count as expected
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
status = connect.flush([collection])
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
connect.delete_entity_by_id(collection, delete_ids)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb - 2
|
||||
assert stats["partitions"][0]["row_count"] == nb -2
|
||||
assert stats["row_count"] == default_nb - 2
|
||||
assert stats["partitions"][0]["row_count"] == default_nb - 2
|
||||
assert stats["partitions"][0]["segments"][0]["data_size"] > 0
|
||||
|
||||
# TODO: enable
|
||||
|
@ -145,20 +133,21 @@ class TestStatsBase:
|
|||
method: add and delete entities, and compact collection, check count in collection info
|
||||
expected: status ok, count as expected
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
delete_length = 1000
|
||||
ids = connect.insert(collection, default_entities)
|
||||
status = connect.flush([collection])
|
||||
delete_ids = ids[:3000]
|
||||
delete_ids = ids[:delete_length]
|
||||
connect.delete_entity_by_id(collection, delete_ids)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(stats)
|
||||
assert stats["row_count"] == nb - 3000
|
||||
assert stats["row_count"] == default_nb - delete_length
|
||||
compact_before = stats["partitions"][0]["segments"][0]["data_size"]
|
||||
connect.compact(collection)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(stats)
|
||||
compact_after = stats["partitions"][0]["segments"][0]["data_size"]
|
||||
assert compact_before > compact_after
|
||||
assert compact_before == compact_after
|
||||
|
||||
def test_get_collection_stats_after_compact_delete_one(self, connect, collection):
|
||||
'''
|
||||
|
@ -166,7 +155,7 @@ class TestStatsBase:
|
|||
method: add and delete one entity, and compact collection, check count in collection info
|
||||
expected: status ok, count as expected
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
status = connect.flush([collection])
|
||||
delete_ids = ids[:1]
|
||||
connect.delete_entity_by_id(collection, delete_ids)
|
||||
|
@ -187,13 +176,13 @@ class TestStatsBase:
|
|||
method: call collection_stats after partition created and check partition_stats
|
||||
expected: status ok, vectors added to partition
|
||||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert len(stats["partitions"]) == 2
|
||||
assert stats["partitions"][1]["tag"] == tag
|
||||
assert stats["partitions"][1]["row_count"] == nb
|
||||
assert stats["partitions"][1]["tag"] == default_tag
|
||||
assert stats["partitions"][1]["row_count"] == default_nb
|
||||
|
||||
def test_get_collection_stats_partitions(self, connect, collection):
|
||||
'''
|
||||
|
@ -202,22 +191,22 @@ class TestStatsBase:
|
|||
expected: status ok, vectors added to one partition but not the other
|
||||
'''
|
||||
new_tag = "new_tag"
|
||||
connect.create_partition(collection, tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
connect.create_partition(collection, new_tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
for partition in stats["partitions"]:
|
||||
if partition["tag"] == tag:
|
||||
assert partition["row_count"] == nb
|
||||
if partition["tag"] == default_tag:
|
||||
assert partition["row_count"] == default_nb
|
||||
else:
|
||||
assert partition["row_count"] == 0
|
||||
ids = connect.insert(collection, entities, partition_tag=new_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=new_tag)
|
||||
connect.flush([collection])
|
||||
stats = connect.get_collection_stats(collection)
|
||||
for partition in stats["partitions"]:
|
||||
if partition["tag"] in [tag, new_tag]:
|
||||
assert partition["row_count"] == nb
|
||||
if partition["tag"] in [default_tag, new_tag]:
|
||||
assert partition["row_count"] == default_nb
|
||||
|
||||
def test_get_collection_stats_after_index_created(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
|
@ -225,15 +214,17 @@ class TestStatsBase:
|
|||
method: create collection, add vectors, create index and call collection_stats
|
||||
expected: status ok, index created and shown in segments
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb
|
||||
logging.getLogger().info(stats)
|
||||
assert stats["row_count"] == default_nb
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
assert file["data_size"] > 0
|
||||
assert file["index_type"] == get_simple_index["index_type"]
|
||||
break
|
||||
|
||||
def test_get_collection_stats_after_index_created_ip(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
|
@ -242,16 +233,17 @@ class TestStatsBase:
|
|||
expected: status ok, index created and shown in segments
|
||||
'''
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
get_simple_index.update({"metric_type": "IP"})
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats["row_count"] == default_nb
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
assert file["data_size"] > 0
|
||||
assert file["index_type"] == get_simple_index["index_type"]
|
||||
break
|
||||
|
||||
def test_get_collection_stats_after_index_created_jac(self, connect, binary_collection, get_jaccard_index):
|
||||
'''
|
||||
|
@ -259,15 +251,16 @@ class TestStatsBase:
|
|||
method: create collection, add binary entities, create index and call collection_stats
|
||||
expected: status ok, index created and shown in segments
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
connect.create_index(binary_collection, "binary_vector", get_jaccard_index)
|
||||
stats = connect.get_collection_stats(binary_collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats["row_count"] == default_nb
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
assert file["data_size"] > 0
|
||||
assert file["index_type"] == get_simple_index["index_type"]
|
||||
break
|
||||
|
||||
def test_get_collection_stats_after_create_different_index(self, connect, collection):
|
||||
'''
|
||||
|
@ -275,16 +268,18 @@ class TestStatsBase:
|
|||
method: create collection, add vectors, create index and call collection_stats multiple times
|
||||
expected: status ok, index info shown in segments
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
for index_type in ["IVF_FLAT", "IVF_SQ8"]:
|
||||
connect.create_index(collection, field_name, {"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
connect.create_index(collection, default_float_vec_field_name,
|
||||
{"index_type": index_type, "params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
stats = connect.get_collection_stats(collection)
|
||||
assert stats["row_count"] == nb
|
||||
assert stats["row_count"] == default_nb
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
assert file["data_size"] > 0
|
||||
assert file["index_type"] == index_type
|
||||
break
|
||||
|
||||
def test_collection_count_multi_collections(self, connect):
|
||||
'''
|
||||
|
@ -296,14 +291,14 @@ class TestStatsBase:
|
|||
collection_list = []
|
||||
collection_num = 10
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
res = connect.insert(collection_name, entities)
|
||||
res = connect.insert(collection_name, default_entities)
|
||||
connect.flush(collection_list)
|
||||
for i in range(collection_num):
|
||||
stats = connect.get_collection_stats(collection_list[i])
|
||||
assert stats["partitions"][0]["row_count"] == nb
|
||||
assert stats["partitions"][0]["row_count"] == default_nb
|
||||
connect.drop_collection(collection_list[i])
|
||||
|
||||
@pytest.mark.level(2)
|
||||
|
@ -317,23 +312,27 @@ class TestStatsBase:
|
|||
collection_list = []
|
||||
collection_num = 10
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
res = connect.insert(collection_name, entities)
|
||||
res = connect.insert(collection_name, default_entities)
|
||||
connect.flush(collection_list)
|
||||
if i % 2:
|
||||
connect.create_index(collection_name, field_name, {"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
connect.create_index(collection_name, default_float_vec_field_name,
|
||||
{"index_type": "IVF_SQ8", "params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
else:
|
||||
connect.create_index(collection_name, field_name, {"index_type": "IVF_FLAT","params":{ "nlist": 1024}, "metric_type": "L2"})
|
||||
connect.create_index(collection_name, default_float_vec_field_name,
|
||||
{"index_type": "IVF_FLAT","params":{"nlist": 1024}, "metric_type": "L2"})
|
||||
for i in range(collection_num):
|
||||
stats = connect.get_collection_stats(collection_list[i])
|
||||
if i % 2:
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
assert file["index_type"] == "IVF_SQ8"
|
||||
break
|
||||
else:
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
if file["name"] == default_float_vec_field_name and "index_type" in file:
|
||||
assert file["index_type"] == "IVF_FLAT"
|
||||
break
|
||||
connect.drop_collection(collection_list[i])
|
||||
|
|
|
@ -9,18 +9,11 @@ import sklearn.preprocessing
|
|||
|
||||
import pytest
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
nb = 1
|
||||
dim = 128
|
||||
collection_id = "create_collection"
|
||||
default_segment_row_count = 512 * 1024
|
||||
drop_collection_interval_time = 3
|
||||
segment_row_count = 5000
|
||||
default_fields = gen_default_fields()
|
||||
entities = gen_entities(nb)
|
||||
uid = "create_collection"
|
||||
|
||||
class TestCreateCollection:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_collection` function
|
||||
|
@ -42,9 +35,9 @@ class TestCreateCollection:
|
|||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_segment_row_counts()
|
||||
params=gen_segment_row_limits()
|
||||
)
|
||||
def get_segment_row_count(self, request):
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
|
||||
|
@ -56,10 +49,10 @@ class TestCreateCollection:
|
|||
filter_field = get_filter_field
|
||||
logging.getLogger().info(filter_field)
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": segment_row_count
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
logging.getLogger().info(fields)
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
@ -73,23 +66,23 @@ class TestCreateCollection:
|
|||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": segment_row_count
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
|
||||
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_count
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_count
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
assert connect.has_collection(collection_name)
|
||||
|
||||
|
@ -101,7 +94,7 @@ class TestCreateCollection:
|
|||
expected: create status return ok
|
||||
'''
|
||||
disable_flush(connect)
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
try:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
finally:
|
||||
|
@ -115,7 +108,7 @@ class TestCreateCollection:
|
|||
expected: error raised
|
||||
'''
|
||||
# pdb.set_trace()
|
||||
connect.insert(collection, entities)
|
||||
connect.insert(collection, default_entity)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
@ -126,7 +119,7 @@ class TestCreateCollection:
|
|||
method: insert vector and create collection
|
||||
expected: error raised
|
||||
'''
|
||||
connect.insert(collection, entities)
|
||||
connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
@ -136,9 +129,9 @@ class TestCreateCollection:
|
|||
'''
|
||||
target: test create collection, without connection
|
||||
method: create collection with correct params, with a disconnected instance
|
||||
expected: create raise exception
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
|
@ -146,13 +139,23 @@ class TestCreateCollection:
|
|||
'''
|
||||
target: test create collection but the collection name have already existed
|
||||
method: create collection with the same collection_name
|
||||
expected: create status return not ok
|
||||
expected: error raised
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def test_create_after_drop_collection(self, connect, collection):
|
||||
'''
|
||||
target: create with the same collection name after collection dropped
|
||||
method: delete, then create
|
||||
expected: create success
|
||||
'''
|
||||
connect.drop_collection(collection)
|
||||
time.sleep(2)
|
||||
connect.create_collection(collection, default_fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_multithread(self, connect):
|
||||
'''
|
||||
|
@ -165,7 +168,7 @@ class TestCreateCollection:
|
|||
collection_names = []
|
||||
|
||||
def create():
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
for i in range(threads_num):
|
||||
|
@ -176,9 +179,9 @@ class TestCreateCollection:
|
|||
for t in threads:
|
||||
t.join()
|
||||
|
||||
res = connect.list_collections()
|
||||
for item in collection_names:
|
||||
assert item in res
|
||||
assert item in connect.list_collections()
|
||||
connect.drop_collection(item)
|
||||
|
||||
|
||||
class TestCreateCollectionInvalid(object):
|
||||
|
@ -196,7 +199,7 @@ class TestCreateCollectionInvalid(object):
|
|||
scope="function",
|
||||
params=gen_invalid_ints()
|
||||
)
|
||||
def get_segment_row_count(self, request):
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
|
@ -221,21 +224,13 @@ class TestCreateCollectionInvalid(object):
|
|||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_invalid_segment_row_count(self, connect, get_segment_row_count):
|
||||
def test_create_collection_with_invalid_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
collection_name = gen_unique_str()
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_count
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
# @pytest.mark.level(2)
|
||||
# def test_create_collection_with_invalid_metric_type(self, connect, get_metric_type):
|
||||
# collection_name = gen_unique_str()
|
||||
# fields = copy.deepcopy(default_fields)
|
||||
# fields["fields"][-1]["params"]["metric_type"] = get_metric_type
|
||||
# with pytest.raises(Exception) as e:
|
||||
# connect.create_collection(collection_name, fields)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_with_invalid_dimension(self, connect, get_dim):
|
||||
dimension = get_dim
|
||||
|
@ -278,54 +273,55 @@ class TestCreateCollectionInvalid(object):
|
|||
method: create collection with correct params
|
||||
expected: create status return ok
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["fields"][-1]["params"].pop("dim")
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
def test_create_collection_no_segment_row_count(self, connect):
|
||||
def test_create_collection_no_segment_row_limit(self, connect):
|
||||
'''
|
||||
target: test create collection with no segment_row_count params
|
||||
method: create collection with corrent params
|
||||
expected: use default default_segment_row_count
|
||||
target: test create collection with no segment_row_limit params
|
||||
method: create collection with correct params
|
||||
expected: use default default_segment_row_limit
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields.pop("segment_row_limit")
|
||||
connect.create_collection(collection_name, fields)
|
||||
res = connect.get_collection_info(collection_name)
|
||||
logging.getLogger().info(res)
|
||||
assert res["segment_row_limit"] == default_segment_row_count
|
||||
assert res["segment_row_limit"] == default_server_segment_row_limit
|
||||
|
||||
# TODO: assert exception
|
||||
def test_create_collection_limit_fields(self, connect):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
limit_num = 64
|
||||
fields = copy.deepcopy(default_fields)
|
||||
for i in range(limit_num):
|
||||
field_name = gen_unique_str("field_name")
|
||||
field = {"field": field_name, "type": DataType.INT64}
|
||||
field = {"name": field_name, "type": DataType.INT64}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
# TODO: assert exception
|
||||
@pytest.mark.level(2)
|
||||
def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
field_name = get_invalid_string
|
||||
field = {"field": field_name, "type": DataType.INT64}
|
||||
field = {"name": field_name, "type": DataType.INT64}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
||||
# TODO: assert exception
|
||||
def test_create_collection_invalid_field_type(self, connect, get_field_type):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
field_type = get_field_type
|
||||
field = {"field": "test_field", "type": field_type}
|
||||
field = {"name": "test_field", "type": field_type}
|
||||
fields["fields"].append(field)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_collection(collection_name, fields)
|
||||
|
|
|
@ -3,15 +3,14 @@ import pytest
|
|||
import logging
|
||||
import itertools
|
||||
from time import sleep
|
||||
import threading
|
||||
from multiprocessing import Process
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
uniq_id = "drop_collection"
|
||||
default_fields = gen_default_fields()
|
||||
|
||||
|
||||
class TestDropCollection:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_collection` function
|
||||
|
@ -48,6 +47,33 @@ class TestDropCollection:
|
|||
with pytest.raises(Exception) as e:
|
||||
connect.drop_collection(collection_name)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_drop_collection_multithread(self, connect):
|
||||
'''
|
||||
target: test create and drop collection with multithread
|
||||
method: create and drop collections using multiple threads
|
||||
expected: collections are created and dropped
|
||||
'''
|
||||
threads_num = 8
|
||||
threads = []
|
||||
collection_names = []
|
||||
|
||||
def create():
|
||||
collection_name = gen_unique_str(uniq_id)
|
||||
collection_names.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
connect.drop_collection(collection_name)
|
||||
for i in range(threads_num):
|
||||
t = threading.Thread(target=create, args=())
|
||||
threads.append(t)
|
||||
t.start()
|
||||
time.sleep(0.2)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
for item in collection_names:
|
||||
assert not connect.has_collection(item)
|
||||
|
||||
|
||||
class TestDropCollectionInvalid(object):
|
||||
"""
|
||||
|
@ -60,6 +86,7 @@ class TestDropCollectionInvalid(object):
|
|||
def get_collection_name(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_collection_with_invalid_collectionname(self, connect, get_collection_name):
|
||||
collection_name = get_collection_name
|
||||
with pytest.raises(Exception) as e:
|
||||
|
|
|
@ -6,13 +6,9 @@ from time import sleep
|
|||
import threading
|
||||
from multiprocessing import Process
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
nb = 1000
|
||||
collection_id = "info"
|
||||
default_fields = gen_default_fields()
|
||||
segment_row_count = 5000
|
||||
field_name = "float_vector"
|
||||
|
||||
uid = "collection_info"
|
||||
|
||||
class TestInfoBase:
|
||||
|
||||
|
@ -32,9 +28,9 @@ class TestInfoBase:
|
|||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_segment_row_counts()
|
||||
params=gen_segment_row_limits()
|
||||
)
|
||||
def get_segment_row_count(self, request):
|
||||
def get_segment_row_limit(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(
|
||||
|
@ -62,15 +58,15 @@ class TestInfoBase:
|
|||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": segment_row_count
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == segment_row_count
|
||||
assert res['segment_row_limit'] == default_segment_row_limit
|
||||
assert len(res["fields"]) == 2
|
||||
for field in res["fields"]:
|
||||
if field["type"] == filter_field:
|
||||
|
@ -79,25 +75,25 @@ class TestInfoBase:
|
|||
assert field["name"] == vector_field["name"]
|
||||
assert field["params"] == vector_field["params"]
|
||||
|
||||
def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
|
||||
def test_create_collection_segment_row_limit(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_count
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_count
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
# assert segment row count
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['segment_row_limit'] == get_segment_row_count
|
||||
assert res['segment_row_limit'] == get_segment_row_limit
|
||||
|
||||
def test_get_collection_info_after_index_created(self, connect, collection, get_simple_index):
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
res = connect.get_collection_info(collection)
|
||||
for field in res["fields"]:
|
||||
if field["field"] == field_name:
|
||||
if field["name"] == default_float_vec_field_name:
|
||||
index = field["indexes"][0]
|
||||
assert index["index_type"] == get_simple_index["index_type"]
|
||||
assert index["metric_type"] == get_simple_index["metric_type"]
|
||||
|
@ -119,7 +115,7 @@ class TestInfoBase:
|
|||
assert the value returned by get_collection_info method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.get_collection_info(connect, collection_name)
|
||||
|
||||
|
@ -132,7 +128,7 @@ class TestInfoBase:
|
|||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def get_info():
|
||||
|
@ -161,18 +157,18 @@ class TestInfoBase:
|
|||
'''
|
||||
filter_field = get_filter_field
|
||||
vector_field = get_vector_field
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = {
|
||||
"fields": [filter_field, vector_field],
|
||||
"segment_row_limit": segment_row_count
|
||||
"segment_row_limit": default_segment_row_limit
|
||||
}
|
||||
connect.create_collection(collection_name, fields)
|
||||
entities = gen_entities_by_fields(fields["fields"], nb, vector_field["params"]["dim"])
|
||||
entities = gen_entities_by_fields(fields["fields"], default_nb, vector_field["params"]["dim"])
|
||||
res_ids = connect.insert(collection_name, entities)
|
||||
connect.flush([collection_name])
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == segment_row_count
|
||||
assert res['segment_row_limit'] == default_segment_row_limit
|
||||
assert len(res["fields"]) == 2
|
||||
for field in res["fields"]:
|
||||
if field["type"] == filter_field:
|
||||
|
@ -181,22 +177,22 @@ class TestInfoBase:
|
|||
assert field["name"] == vector_field["name"]
|
||||
assert field["params"] == vector_field["params"]
|
||||
|
||||
def test_create_collection_segment_row_count_after_insert(self, connect, get_segment_row_count):
|
||||
def test_create_collection_segment_row_limit_after_insert(self, connect, get_segment_row_limit):
|
||||
'''
|
||||
target: test create normal collection with different fields
|
||||
method: create collection with diff segment_row_count
|
||||
method: create collection with diff segment_row_limit
|
||||
expected: no exception raised
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_fields)
|
||||
fields["segment_row_limit"] = get_segment_row_count
|
||||
fields["segment_row_limit"] = get_segment_row_limit
|
||||
connect.create_collection(collection_name, fields)
|
||||
entities = gen_entities_by_fields(fields["fields"], nb, fields["fields"][-1]["params"]["dim"])
|
||||
entities = gen_entities_by_fields(fields["fields"], default_nb, fields["fields"][-1]["params"]["dim"])
|
||||
res_ids = connect.insert(collection_name, entities)
|
||||
connect.flush([collection_name])
|
||||
res = connect.get_collection_info(collection_name)
|
||||
assert res['auto_id'] == True
|
||||
assert res['segment_row_limit'] == get_segment_row_count
|
||||
assert res['segment_row_limit'] == get_segment_row_limit
|
||||
|
||||
|
||||
class TestInfoInvalid(object):
|
||||
|
|
|
@ -6,13 +6,11 @@ import threading
|
|||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
collection_id = "has_collection"
|
||||
default_fields = gen_default_fields()
|
||||
|
||||
uid = "has_collection"
|
||||
|
||||
class TestHasCollection:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `has_collection` function
|
||||
|
@ -55,7 +53,7 @@ class TestHasCollection:
|
|||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def has():
|
||||
|
|
|
@ -6,15 +6,11 @@ import threading
|
|||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
|
||||
drop_interval_time = 3
|
||||
collection_id = "list_collections"
|
||||
default_fields = gen_default_fields()
|
||||
|
||||
uid = "list_collections"
|
||||
|
||||
class TestListCollections:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `list_collections` function
|
||||
|
@ -36,7 +32,7 @@ class TestListCollections:
|
|||
'''
|
||||
collection_num = 50
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
assert collection_name in connect.list_collections()
|
||||
|
||||
|
@ -57,7 +53,7 @@ class TestListCollections:
|
|||
assert the value returned by list_collections method
|
||||
expected: False
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
assert collection_name not in connect.list_collections()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
|
@ -72,7 +68,7 @@ class TestListCollections:
|
|||
if result:
|
||||
for collection_name in result:
|
||||
connect.drop_collection(collection_name)
|
||||
time.sleep(drop_interval_time)
|
||||
time.sleep(default_drop_interval)
|
||||
result = connect.list_collections()
|
||||
assert len(result) == 0
|
||||
|
||||
|
@ -85,7 +81,7 @@ class TestListCollections:
|
|||
'''
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
|
||||
def _list():
|
||||
|
|
|
@ -5,18 +5,11 @@ import itertools
|
|||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
collection_id = "load_collection"
|
||||
nb = 6000
|
||||
default_fields = gen_default_fields()
|
||||
entities = gen_entities(nb)
|
||||
field_name = default_float_vec_field_name
|
||||
binary_field_name = default_binary_vec_field_name
|
||||
raw_vectors, binary_entities = gen_binary_entities(nb)
|
||||
|
||||
uid = "load_collection"
|
||||
|
||||
class TestLoadBase:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `load_collection` function
|
||||
|
@ -49,10 +42,10 @@ class TestLoadBase:
|
|||
method: insert and create index, load collection with correct params
|
||||
expected: no error raised
|
||||
'''
|
||||
connect.insert(collection, entities)
|
||||
connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
logging.getLogger().info(get_simple_index)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(collection, default_float_vec_field_name, get_simple_index)
|
||||
connect.load_collection(collection)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
|
@ -62,16 +55,16 @@ class TestLoadBase:
|
|||
method: insert and create index, load binary_collection with correct params
|
||||
expected: no error raised
|
||||
'''
|
||||
connect.insert(binary_collection, binary_entities)
|
||||
connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
for metric_type in binary_metrics():
|
||||
logging.getLogger().info(metric_type)
|
||||
get_binary_index["metric_type"] = metric_type
|
||||
if get_binary_index["index_type"] == "BIN_IVF_FLAT" and metric_type in structure_metrics():
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(binary_collection, binary_field_name, get_binary_index)
|
||||
connect.create_index(binary_collection, default_binary_vec_field_name, get_binary_index)
|
||||
else:
|
||||
connect.create_index(binary_collection, binary_field_name, get_binary_index)
|
||||
connect.create_index(binary_collection, default_binary_vec_field_name, get_binary_index)
|
||||
connect.load_collection(binary_collection)
|
||||
|
||||
def load_empty_collection(self, connect, collection):
|
||||
|
@ -94,7 +87,7 @@ class TestLoadBase:
|
|||
|
||||
@pytest.mark.level(2)
|
||||
def test_load_collection_not_existed(self, connect, collection):
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.load_collection(collection_name)
|
||||
|
||||
|
|
|
@ -71,6 +71,7 @@ def connect(request):
|
|||
port = http_port
|
||||
try:
|
||||
milvus = get_milvus(host=ip, port=port, handler=handler)
|
||||
# reset_build_index_threshold(milvus)
|
||||
except Exception as e:
|
||||
logging.getLogger().error(str(e))
|
||||
pytest.exit("Milvus server can not connected, exit pytest ...")
|
||||
|
|
|
@@ -0,0 +1,10 @@
+import utils
+
+default_fields = utils.gen_default_fields()
+default_binary_fields = utils.gen_binary_default_fields()
+
+default_entity = utils.gen_entities(1)
+default_raw_binary_vector, default_binary_entity = utils.gen_binary_entities(1)
+
+default_entities = utils.gen_entities(utils.default_nb)
+default_raw_binary_vectors, default_binary_entities = utils.gen_binary_entities(utils.default_nb)
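For orientation, this new constants module lets every test file share one set of prebuilt payloads instead of regenerating them at module level; a consumer only needs the star-imports shown in the hunks above. A minimal sketch of such a consumer (the connect/collection pytest fixtures come from the suite's conftest):

    from utils import *          # gen_unique_str, default_nb, ...
    from constants import *      # default_fields, default_entities, ...

    uid = "example_module"

    def test_insert_shared_defaults(connect, collection):
        # default_entities carries default_nb rows, built once in constants.py.
        ids = connect.insert(collection, default_entities)
        connect.flush([collection])
        assert len(ids) == default_nb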
@@ -7,23 +7,13 @@ import logging
 from multiprocessing import Pool, Process
 import pytest
 from utils import *
+from constants import *

-dim = 128
-segment_row_count = 5000
-collection_id = "test_delete"
 DELETE_TIMEOUT = 60
-tag = "1970_01_01"
-nb = 6000
 field_name = default_float_vec_field_name
-entity = gen_entities(1)
-raw_vector, binary_entity = gen_binary_entities(1)
-entities = gen_entities(nb)
-raw_vectors, binary_entities = gen_binary_entities(nb)
 default_single_query = {
     "bool": {
         "must": [
-            {"vector": {field_name: {"topk": 10, "metric_type":"L2","query": gen_vectors(1, dim), "params": {"nprobe": 10}}}}
+            {"vector": {field_name: {"topk": 10, "metric_type":"L2", "query": gen_vectors(1, default_dim), "params": {"nprobe": 10}}}}
         ]
     }
 }
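As a reading aid, default_single_query above is the 0.x search DSL that connect.search consumes: one bool/must clause wrapping a single vector sub-query keyed by the field name. Building the same dict for an arbitrary top-k (names taken from the suite's utils module, consistent with the star-imports above):

    from utils import gen_vectors, default_dim, default_float_vec_field_name

    def make_single_query(topk=10, nq=1, nprobe=10):
        return {
            "bool": {
                "must": [
                    {"vector": {
                        default_float_vec_field_name: {
                            "topk": topk,
                            "metric_type": "L2",
                            "query": gen_vectors(nq, default_dim),
                            "params": {"nprobe": nprobe},
                        }
                    }}
                ]
            }
        }

    # Usage inside a test (fixtures assumed): res = connect.search(collection, make_single_query())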
@@ -51,7 +41,7 @@ class TestDeleteBase:
         scope="function",
         params=[
             1,
-            6000
+            2000
         ],
     )
     def insert_count(self, request):

@@ -63,7 +53,7 @@ class TestDeleteBase:
         method: add entity and delete
         expected: status DELETED
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         connect.flush([collection])
         status = connect.delete_entity_by_id(collection, [0])
         assert status

@@ -93,7 +83,7 @@ class TestDeleteBase:
         method: add entity and delete
         expected: error raised
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         connect.flush([collection])
         collection_new = gen_unique_str()
         with pytest.raises(Exception) as e:

@@ -121,14 +111,14 @@ class TestDeleteBase:
         method: add entities and delete one in collection, and one not in collection
         expected: no error raised
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], 1]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == nb - 1
+        assert res_count == default_nb - 1

     def test_insert_delete_B(self, connect, id_collection):
         '''

@@ -136,8 +126,8 @@ class TestDeleteBase:
         method: add entities with the same ids, and delete the id in collection
         expected: no error raised, all entities deleted
         '''
-        ids = [1 for i in range(nb)]
-        res_ids = connect.insert(id_collection, entities, ids)
+        ids = [1 for i in range(default_nb)]
+        res_ids = connect.insert(id_collection, default_entities, ids)
         connect.flush([id_collection])
         delete_ids = [1]
         status = connect.delete_entity_by_id(id_collection, delete_ids)

@@ -152,7 +142,7 @@ class TestDeleteBase:
         method: add one entity and delete two ids
         expected: error raised
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         connect.flush([collection])
         delete_ids = [ids[0], 1]
         status = connect.delete_entity_by_id(collection, delete_ids)

@@ -166,14 +156,14 @@ class TestDeleteBase:
         method: add entities and delete, then flush
         expected: entity deleted and no error raised
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == nb - len(delete_ids)
+        assert res_count == default_nb - len(delete_ids)

     def test_flush_after_delete_binary(self, connect, binary_collection):
         '''

@@ -181,21 +171,21 @@ class TestDeleteBase:
         method: add entities and delete, then flush
         expected: entity deleted and no error raised
         '''
-        ids = connect.insert(binary_collection, binary_entities)
+        ids = connect.insert(binary_collection, default_binary_entities)
         connect.flush([binary_collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(binary_collection, delete_ids)
         assert status
         connect.flush([binary_collection])
         res_count = connect.count_entities(binary_collection)
-        assert res_count == nb - len(delete_ids)
+        assert res_count == default_nb - len(delete_ids)

     def test_insert_delete_binary(self, connect, binary_collection):
         '''
         method: add entities and delete
         expected: status DELETED
         '''
-        ids = connect.insert(binary_collection, binary_entities)
+        ids = connect.insert(binary_collection, default_binary_entities)
         connect.flush([binary_collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(binary_collection, delete_ids)
@@ -206,34 +196,34 @@ class TestDeleteBase:
         expected: status DELETED
         note: Not flush after delete
         '''
-        insert_ids = [i for i in range(nb)]
-        ids = connect.insert(id_collection, entities, insert_ids)
+        insert_ids = [i for i in range(default_nb)]
+        ids = connect.insert(id_collection, default_entities, insert_ids)
         connect.flush([id_collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(id_collection, delete_ids)
         assert status
-        new_ids = connect.insert(id_collection, entity, [ids[0]])
+        new_ids = connect.insert(id_collection, default_entity, [ids[0]])
         assert new_ids == [ids[0]]
         connect.flush([id_collection])
         res_count = connect.count_entities(id_collection)
-        assert res_count == nb - 1
+        assert res_count == default_nb - 1

     def test_insert_same_ids_after_delete_binary(self, connect, binary_id_collection):
         '''
         method: add entities, with the same id and delete the ids
         expected: status DELETED, all id deleted
         '''
-        insert_ids = [i for i in range(nb)]
-        ids = connect.insert(binary_id_collection, binary_entities, insert_ids)
+        insert_ids = [i for i in range(default_nb)]
+        ids = connect.insert(binary_id_collection, default_binary_entities, insert_ids)
         connect.flush([binary_id_collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(binary_id_collection, delete_ids)
         assert status
-        new_ids = connect.insert(binary_id_collection, binary_entity, [ids[0]])
+        new_ids = connect.insert(binary_id_collection, default_binary_entity, [ids[0]])
         assert new_ids == [ids[0]]
         connect.flush([binary_id_collection])
         res_count = connect.count_entities(binary_id_collection)
-        assert res_count == nb - 1
+        assert res_count == default_nb - 1

     def test_search_after_delete(self, connect, collection):
         '''

@@ -241,13 +231,14 @@ class TestDeleteBase:
         method: add entities and delete, then search
         expected: entity deleted and no error raised
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         query = copy.deepcopy(default_single_query)
-        query["bool"]["must"][0]["vector"][field_name]["query"] = [entity[-1]["values"][0], entities[-1]["values"][0], entities[-1]["values"][-1]]
+        query["bool"]["must"][0]["vector"][field_name]["query"] = \
+            [default_entity[-1]["values"][0], default_entities[-1]["values"][0], default_entities[-1]["values"][-1]]
         res = connect.search(collection, query)
         logging.getLogger().debug(res)
         assert len(res) == len(query["bool"]["must"][0]["vector"][field_name]["query"])
@@ -260,7 +251,7 @@ class TestDeleteBase:
         method: add entities and delete, then create index
         expected: vectors deleted, index created
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)

@@ -272,7 +263,7 @@ class TestDeleteBase:
         method: add entities and delete the same id several times
         expected: entities deleted
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)

@@ -288,14 +279,14 @@ class TestDeleteBase:
         expected: entities deleted
         '''
         connect.create_index(collection, field_name, get_simple_index)
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == nb - len(delete_ids)
+        assert res_count == default_nb - len(delete_ids)
         res_get = connect.get_entity_by_id(collection, delete_ids)
         assert res_get[0] is None

@@ -305,17 +296,17 @@ class TestDeleteBase:
         method: create index, insert entities, and delete
         expected: entities deleted
         '''
-        ids = [i for i in range(nb)]
+        ids = [i for i in range(default_nb)]
         connect.create_index(id_collection, field_name, get_simple_index)
-        for i in range(nb):
-            connect.insert(id_collection, entity, [ids[i]])
+        for i in range(default_nb):
+            connect.insert(id_collection, default_entity, [ids[i]])
         connect.flush([id_collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(id_collection, delete_ids)
         assert status
         connect.flush([id_collection])
         res_count = connect.count_entities(id_collection)
-        assert res_count == nb - len(delete_ids)
+        assert res_count == default_nb - len(delete_ids)

     """
     ******************************************************************
@@ -327,30 +318,30 @@ class TestDeleteBase:
         method: add entities with given tag, delete entities with the returned ids
         expected: entities deleted
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == nb - 2
+        assert res_count == default_nb - 2

     def test_insert_default_tag_delete(self, connect, collection):
         '''
         method: add entities, delete entities with the returned ids
         expected: entities deleted
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[0], ids[-1]]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == nb - 2
+        assert res_count == default_nb - 2

     def test_insert_tags_delete(self, connect, collection):
         '''

@@ -358,17 +349,17 @@ class TestDeleteBase:
         expected: entities deleted
         '''
         tag_new = "tag_new"
-        connect.create_partition(collection, tag)
+        connect.create_partition(collection, default_tag)
         connect.create_partition(collection, tag_new)
-        ids = connect.insert(collection, entities, partition_tag=tag)
-        ids_new = connect.insert(collection, entities, partition_tag=tag_new)
+        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+        ids_new = connect.insert(collection, default_entities, partition_tag=tag_new)
         connect.flush([collection])
         delete_ids = [ids[0], ids_new[0]]
         status = connect.delete_entity_by_id(collection, delete_ids)
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == 2 * (nb - 1)
+        assert res_count == 2 * (default_nb - 1)

     def test_insert_tags_index_delete(self, connect, collection, get_simple_index):
         '''

@@ -376,10 +367,10 @@ class TestDeleteBase:
         expected: entities deleted
         '''
         tag_new = "tag_new"
-        connect.create_partition(collection, tag)
+        connect.create_partition(collection, default_tag)
         connect.create_partition(collection, tag_new)
-        ids = connect.insert(collection, entities, partition_tag=tag)
-        ids_new = connect.insert(collection, entities, partition_tag=tag_new)
+        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+        ids_new = connect.insert(collection, default_entities, partition_tag=tag_new)
         connect.flush([collection])
         connect.create_index(collection, field_name, get_simple_index)
         delete_ids = [ids[0], ids_new[0]]

@@ -387,7 +378,7 @@ class TestDeleteBase:
         assert status
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == 2 * (nb - 1)
+        assert res_count == 2 * (default_nb - 1)


 class TestDeleteInvalid(object):
@@ -420,6 +411,7 @@ class TestDeleteInvalid(object):
         with pytest.raises(Exception) as e:
             status = connect.delete_entity_by_id(collection, [1, invalid_id])

+    @pytest.mark.level(2)
     def test_delete_entity_with_invalid_collection_name(self, connect, get_collection_name):
         collection_name = get_collection_name
         with pytest.raises(Exception) as e:
@@ -2,28 +2,19 @@ import time
 import random
 import pdb
 import copy
 import threading
 import logging
 from multiprocessing import Pool, Process
 import concurrent.futures
 from threading import current_thread
 import pytest
 from utils import *
+from constants import *

-dim = 128
-segment_row_count = 5000
-collection_id = "test_get"
-DELETE_TIMEOUT = 60
-tag = "1970_01_01"
-nb = 6000
-entity = gen_entities(1)
-binary_entity = gen_binary_entities(1)
-entities = gen_entities(nb)
-raw_vectors, binary_entities = gen_binary_entities(nb)
 default_single_query = {
     "bool": {
         "must": [
-            {"vector": {default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, dim), "params": {"nprobe": 10}}}}
+            {"vector": {
+                default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "params": {"nprobe": 10}}}}
         ]
     }
 }
@@ -34,6 +25,7 @@ class TestGetBase:
     The following cases are used to test `get_entity_by_id` function
     ******************************************************************
     """

     @pytest.fixture(
         scope="function",
         params=gen_simple_index()

@@ -48,8 +40,6 @@ class TestGetBase:
         scope="function",
         params=[
             1,
-            10,
-            100,
             500
         ],
     )
@@ -62,13 +52,13 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == nb
+        assert res_count == default_nb
         get_ids = [ids[get_pos]]
         res = connect.get_entity_by_id(collection, get_ids)
-        assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][get_pos])
+        assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][get_pos])

     def test_get_entity_multi_ids(self, connect, collection, get_pos):
         '''

@@ -76,12 +66,12 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     def test_get_entity_parts_ids(self, connect, collection):
         '''

@@ -89,12 +79,12 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = [ids[0], 1, ids[-1]]
         res = connect.get_entity_by_id(collection, get_ids)
-        assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][0])
-        assert_equal_vector(res[-1].get(default_float_vec_field_name), entities[-1]["values"][-1])
+        assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][0])
+        assert_equal_vector(res[-1].get(default_float_vec_field_name), default_entities[-1]["values"][-1])
         assert res[1] is None

     def test_get_entity_limit(self, connect, collection, args):

@@ -106,7 +96,7 @@ class TestGetBase:
         if args["handler"] == "HTTP":
             pytest.skip("skip in http mode")

-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         with pytest.raises(Exception) as e:
             res = connect.get_entity_by_id(collection, ids)

@@ -117,13 +107,13 @@ class TestGetBase:
         method: add entity, and get one id
         expected: entity returned equals insert
         '''
-        ids = [1 for i in range(nb)]
-        res_ids = connect.insert(id_collection, entities, ids)
+        ids = [1 for i in range(default_nb)]
+        res_ids = connect.insert(id_collection, default_entities, ids)
         connect.flush([id_collection])
         get_ids = [ids[0]]
         res = connect.get_entity_by_id(id_collection, get_ids)
         assert len(res) == 1
-        assert_equal_vector(res[0].get(default_float_vec_field_name), entities[-1]["values"][0])
+        assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][0])

     def test_get_entity_params_same_ids(self, connect, id_collection):
         '''

@@ -132,14 +122,14 @@ class TestGetBase:
         expected: entity returned equals insert
         '''
         ids = [1]
-        res_ids = connect.insert(id_collection, entity, ids)
+        res_ids = connect.insert(id_collection, default_entity, ids)
         connect.flush([id_collection])
         get_ids = [1, 1]
         res = connect.get_entity_by_id(id_collection, get_ids)
         assert len(res) == len(get_ids)
         for i in range(len(get_ids)):
             logging.getLogger().info(i)
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entity[-1]["values"][0])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entity[-1]["values"][0])

     def test_get_entities_params_same_ids(self, connect, collection):
         '''

@@ -147,13 +137,13 @@ class TestGetBase:
         method: add entities, and get entity with the same ids
         expected: entity returned equals insert
         '''
-        res_ids = connect.insert(collection, entities)
+        res_ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = [res_ids[0], res_ids[0]]
         res = connect.get_entity_by_id(collection, get_ids)
         assert len(res) == len(get_ids)
         for i in range(len(get_ids)):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][0])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][0])

     """
     ******************************************************************

@@ -167,12 +157,12 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(binary_collection, binary_entities)
+        ids = connect.insert(binary_collection, default_binary_entities)
         connect.flush([binary_collection])
         get_ids = [ids[0], 1, ids[-1]]
         res = connect.get_entity_by_id(binary_collection, get_ids)
-        assert_equal_vector(res[0].get("binary_vector"), binary_entities[-1]["values"][0])
-        assert_equal_vector(res[-1].get("binary_vector"), binary_entities[-1]["values"][-1])
+        assert_equal_vector(res[0].get("binary_vector"), default_binary_entities[-1]["values"][0])
+        assert_equal_vector(res[-1].get("binary_vector"), default_binary_entities[-1]["values"][-1])
         assert res[1] is None

     """
@@ -180,19 +170,20 @@ class TestGetBase:
     The following cases are used to test `get_entity_by_id` function, with tags
     ******************************************************************
     """

     def test_get_entities_tag(self, connect, collection, get_pos):
         '''
         target: test.get_entity_by_id
         method: add entities with tag, get
         expected: entity returned
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag = default_tag)
         connect.flush([collection])
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     def test_get_entities_tag_default(self, connect, collection, get_pos):
         '''

@@ -200,13 +191,13 @@ class TestGetBase:
         method: add entities with default tag, get
         expected: entity returned
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     def test_get_entities_tags_default(self, connect, collection, get_pos):
         '''

@@ -215,14 +206,14 @@ class TestGetBase:
         expected: entity returned
         '''
         tag_new = "tag_new"
-        connect.create_partition(collection, tag)
+        connect.create_partition(collection, default_tag)
         connect.create_partition(collection, tag_new)
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     def test_get_entities_tags_A(self, connect, collection, get_pos):
         '''

@@ -231,14 +222,14 @@ class TestGetBase:
         expected: entity returned
         '''
         tag_new = "tag_new"
-        connect.create_partition(collection, tag)
+        connect.create_partition(collection, default_tag)
         connect.create_partition(collection, tag_new)
-        ids = connect.insert(collection, entities, partition_tag=tag)
+        ids = connect.insert(collection, default_entities, partition_tag = default_tag)
         connect.flush([collection])
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     def test_get_entities_tags_B(self, connect, collection, get_pos):
         '''

@@ -247,19 +238,19 @@ class TestGetBase:
         expected: entity returned
         '''
         tag_new = "tag_new"
-        connect.create_partition(collection, tag)
+        connect.create_partition(collection, default_tag)
         connect.create_partition(collection, tag_new)
-        new_entities = gen_entities(nb+1)
-        ids = connect.insert(collection, entities, partition_tag=tag)
-        ids_new = connect.insert(collection, new_entities, partition_tag=tag_new)
+        new_entities = gen_entities(default_nb + 1)
+        ids = connect.insert(collection, default_entities, partition_tag = default_tag)
+        ids_new = connect.insert(collection, new_entities, partition_tag = tag_new)
         connect.flush([collection])
         get_ids = ids[:get_pos]
         get_ids.extend(ids_new[:get_pos])
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
-        for i in range(get_pos, get_pos*2):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), new_entities[-1]["values"][i-get_pos])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
+        for i in range(get_pos, get_pos * 2):
+            assert_equal_vector(res[i].get(default_float_vec_field_name), new_entities[-1]["values"][i - get_pos])

     @pytest.mark.level(2)
     def test_get_entities_indexed_tag(self, connect, collection, get_simple_index, get_pos):
@@ -268,27 +259,28 @@ class TestGetBase:
         method: add entities with tag, get
         expected: entity returned
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag = default_tag)
         connect.flush([collection])
         connect.create_index(collection, default_float_vec_field_name, get_simple_index)
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     """
     ******************************************************************
     The following cases are used to test `get_entity_by_id` function, with fields params
     ******************************************************************
     """

     def test_get_entity_field(self, connect, collection, get_pos):
         '''
         target: test.get_entity_by_id, get one
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = [ids[get_pos]]
         fields = ["int64"]

@@ -296,7 +288,7 @@ class TestGetBase:
         # assert fields
         res = res.dict()
         assert res[0]["field"] == fields[0]
-        assert res[0]["values"] == [entities[0]["values"][get_pos]]
+        assert res[0]["values"] == [default_entities[0]["values"][get_pos]]
         assert res[0]["type"] == DataType.INT64

     def test_get_entity_fields(self, connect, collection, get_pos):

@@ -305,7 +297,7 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = [ids[get_pos]]
         fields = ["int64", "float", default_float_vec_field_name]

@@ -315,11 +307,11 @@ class TestGetBase:
         assert len(res) == len(fields)
         for field in res:
             if field["field"] == fields[0]:
-                assert field["values"] == [entities[0]["values"][get_pos]]
+                assert field["values"] == [default_entities[0]["values"][get_pos]]
             elif field["field"] == fields[1]:
-                assert field["values"] == [entities[1]["values"][get_pos]]
+                assert field["values"] == [default_entities[1]["values"][get_pos]]
             else:
-                assert_equal_vector(field["values"][0], entities[-1]["values"][get_pos])
+                assert_equal_vector(field["values"][0], default_entities[-1]["values"][get_pos])

     # TODO: assert exception
     def test_get_entity_field_not_match(self, connect, collection, get_pos):

@@ -328,7 +320,7 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = [ids[get_pos]]
         fields = ["int1288"]

@@ -342,7 +334,7 @@ class TestGetBase:
         method: add entity, and get
         expected: entity returned equals insert
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_ids = [ids[get_pos]]
         fields = ["int1288"]

@@ -355,9 +347,9 @@ class TestGetBase:
         method: add entity and get
         expected: empty result
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         connect.flush([collection])
         res = connect.get_entity_by_id(collection, [1])
         assert res[0] is None

     def test_get_entity_collection_not_existed(self, connect, collection):
@@ -366,7 +358,7 @@ class TestGetBase:
         method: add entity and get
         expected: error raised
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         connect.flush([collection])
         collection_new = gen_unique_str()
         with pytest.raises(Exception) as e:

@@ -377,13 +369,14 @@ class TestGetBase:
     The following cases are used to test `get_entity_by_id` function, after deleted
     ******************************************************************
     """

     def test_get_entity_after_delete(self, connect, collection, get_pos):
         '''
         target: test.get_entity_by_id
         method: add entities, and delete, get entity by the given id
         expected: empty result
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = [ids[get_pos]]
         status = connect.delete_entity_by_id(collection, delete_ids)

@@ -398,7 +391,7 @@ class TestGetBase:
         method: add entities, and delete, get entity by the given id
         expected: empty result
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = ids[:get_pos]
         status = connect.delete_entity_by_id(collection, delete_ids)

@@ -414,7 +407,7 @@ class TestGetBase:
         method: add entities, and delete, get entity by the given id
         expected: empty result
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = ids[:get_pos]
         status = connect.delete_entity_by_id(collection, delete_ids)
@@ -431,13 +424,13 @@ class TestGetBase:
         method: add entities batch, create index, get
         expected: entity returned
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         connect.create_index(collection, default_float_vec_field_name, get_simple_index)
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

     @pytest.mark.level(2)
     def test_get_entities_indexed_single(self, connect, collection, get_simple_index, get_pos):

@@ -447,14 +440,31 @@ class TestGetBase:
         expected: entity returned
         '''
         ids = []
-        for i in range(nb):
-            ids.append(connect.insert(collection, entity)[0])
+        for i in range(default_nb):
+            ids.append(connect.insert(collection, default_entity)[0])
         connect.flush([collection])
         connect.create_index(collection, default_float_vec_field_name, get_simple_index)
         get_ids = ids[:get_pos]
         res = connect.get_entity_by_id(collection, get_ids)
         for i in range(get_pos):
-            assert_equal_vector(res[i].get(default_float_vec_field_name), entity[-1]["values"][0])
+            assert_equal_vector(res[i].get(default_float_vec_field_name), default_entity[-1]["values"][0])

+    def test_get_entities_with_deleted_ids(self, connect, id_collection):
+        '''
+        target: test.get_entity_by_id
+        method: add entities ids, and delete part, get entity include the deleted id
+        expected:
+        '''
+        ids = [i for i in range(default_nb)]
+        res_ids = connect.insert(id_collection, default_entities, ids)
+        connect.flush([id_collection])
+        status = connect.delete_entity_by_id(id_collection, [res_ids[1]])
+        connect.flush([id_collection])
+        get_ids = res_ids[:2]
+        res = connect.get_entity_by_id(id_collection, get_ids)
+        assert len(res) == len(get_ids)
+        assert_equal_vector(res[0].get(default_float_vec_field_name), default_entities[-1]["values"][0])
+        assert res[1] is None

     # TODO: unable to set config
     def _test_get_entities_after_delete_disable_autoflush(self, connect, collection, get_pos):
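The test added above pins down a useful read-path contract: get_entity_by_id preserves positional alignment with the requested ids, and a deleted id yields None in its slot rather than raising. Restated as a standalone check (suite fixtures and the shared constants assumed):

    def check_deleted_slot(connect, id_collection):
        ids = list(range(default_nb))
        res_ids = connect.insert(id_collection, default_entities, ids)
        connect.flush([id_collection])
        connect.delete_entity_by_id(id_collection, [res_ids[1]])
        connect.flush([id_collection])
        res = connect.get_entity_by_id(id_collection, res_ids[:2])
        assert len(res) == 2       # one result slot per requested id
        assert res[0] is not None  # the surviving entity is still readable
        assert res[1] is None      # the deleted id maps to None, not an error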
@@ -463,7 +473,7 @@ class TestGetBase:
         method: disable autoflush, add entities, and delete, get entity by the given id
         expected: empty result
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         delete_ids = ids[:get_pos]
         try:

@@ -472,7 +482,7 @@ class TestGetBase:
             get_ids = ids[:get_pos]
             res = connect.get_entity_by_id(collection, get_ids)
             for i in range(get_pos):
-                assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+                assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])
         finally:
             enable_flush(connect)

@@ -482,9 +492,9 @@ class TestGetBase:
         method: add entities with the same ids, and delete, get entity by the given id
         expected: empty result
         '''
-        ids = [i for i in range(nb)]
+        ids = [i for i in range(default_nb)]
         ids[0] = 1
-        res_ids = connect.insert(id_collection, entities, ids)
+        res_ids = connect.insert(id_collection, default_entities, ids)
         connect.flush([id_collection])
         status = connect.delete_entity_by_id(id_collection, [1])
         connect.flush([id_collection])

@@ -498,8 +508,8 @@ class TestGetBase:
         method: add entities into partition, and delete, get entity by the given id
         expected: get one entity
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag = default_tag)
         connect.flush([collection])
         status = connect.delete_entity_by_id(collection, [ids[get_pos]])
         connect.flush([collection])
@@ -507,14 +517,16 @@ class TestGetBase:
         assert res[0] is None

     def test_get_entity_by_id_multithreads(self, connect, collection):
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_id = ids[100:200]

         def get():
             res = connect.get_entity_by_id(collection, get_id)
             assert len(res) == len(get_id)
             for i in range(len(res)):
-                assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][100+i])
+                assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][100 + i])

         with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
             future_results = {executor.submit(
                 get): i for i in range(10)}
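The executor pattern above recurs in the tests that follow, so its shape is worth isolating: submit the same zero-argument closure several times, then call result() on each future so any assertion raised inside a worker thread is re-raised in the test itself. A self-contained sketch with a stand-in task:

    import concurrent.futures

    def run_concurrently(task, times=10, workers=5):
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [executor.submit(task) for _ in range(times)]
            for future in concurrent.futures.as_completed(futures):
                future.result()  # surfaces worker-thread exceptions here

    run_concurrently(lambda: None)  # in the tests, the task wraps get_entity_by_id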
@@ -528,14 +540,14 @@ class TestGetBase:
         method: thread do insert and get
         expected:
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         get_id = ids[:1000]

         def insert():
             # logging.getLogger().info(current_thread().getName() + " insert")
             step = 1000
-            for i in range(nb // step):
+            for i in range(default_nb // step):
                 group_entities = gen_entities(step, False)
                 connect.insert(collection, group_entities)
                 connect.flush([collection])

@@ -545,7 +557,7 @@ class TestGetBase:
             res = connect.get_entity_by_id(collection, get_id)
             assert len(res) == len(get_id)
             for i in range(len(res)):
-                assert_equal_vector(res[i].get(default_float_vec_field_name), entities[-1]["values"][i])
+                assert_equal_vector(res[i].get(default_float_vec_field_name), default_entities[-1]["values"][i])

         with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
             for i in range(20):

@@ -554,7 +566,7 @@ class TestGetBase:
                 future.result()

     @pytest.mark.level(2)
-    def test_get_entity_by_id_insert_multi_threads(self, connect, collection):
+    def test_get_entity_by_id_insert_multi_threads_2(self, connect, collection):
         '''
         target: test.get_entity_by_id
         method: thread do insert and get

@@ -572,16 +584,16 @@ class TestGetBase:
                 # logging.getLogger().info(current_thread().getName() + " insert")
                 for group_vector in group_vectors:
                     group_entities = [
-                        {"field": "int64", "type": DataType.INT64, "values": [i for i in range(step)]},
-                        {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(step)]},
-                        {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": group_vector}
+                        {"name": "int64", "type": DataType.INT64, "values": [i for i in range(step)]},
+                        {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(step)]},
+                        {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": group_vector}
                     ]
                     group_ids = connect.insert(collection, group_entities)
                     connect.flush([collection])
                     executor.submit(get, group_ids, group_entities)

             step = 100
-            vectors = gen_vectors(nb, dimension, False)
+            vectors = gen_vectors(default_nb, default_dim, False)
             group_vectors = [vectors[i:i + step] for i in range(0, len(vectors), step)]
             task = executor.submit(insert, group_vectors)
             task.result()
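Note the key rename threaded through the hunk above: insert payloads switch the field-identifier key from "field" to "name", which is the shape the 0.11-era SDK expects (the same rename shows up again in the test_insert hunks below). Schematically:

    from milvus import DataType

    step = 2  # toy batch size

    # pre-0.11 payload shape (removed):
    old_entities = [{"field": "int64", "type": DataType.INT64, "values": list(range(step))}]

    # 0.11 payload shape (added):
    new_entities = [{"name": "int64", "type": DataType.INT64, "values": list(range(step))}]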
@@ -591,6 +603,7 @@ class TestGetInvalid(object):
     """
     Test get entities with invalid params
     """

     @pytest.fixture(
         scope="function",
         params=gen_invalid_strs()

@@ -620,7 +633,7 @@ class TestGetInvalid(object):
         expected: raise an exception
         '''
         entity_id = get_entity_id
-        ids = [entity_id for _ in range(nb)]
+        ids = [entity_id for _ in range(default_nb)]
         with pytest.raises(Exception):
             connect.get_entity_by_id(collection, ids)

@@ -632,7 +645,7 @@ class TestGetInvalid(object):
         expected: raise an exception
         '''
         entity_id = get_entity_id
-        ids = [i for i in range(nb)]
+        ids = [i for i in range(default_nb)]
         ids[-1] = entity_id
         with pytest.raises(Exception):
             connect.get_entity_by_id(collection, ids)

@@ -650,4 +663,4 @@ class TestGetInvalid(object):
         ids = [1]
         fields = [field_name]
         with pytest.raises(Exception):
-            res = connect.get_entity_by_id(collection, ids, fields=fields)
+            res = connect.get_entity_by_id(collection, ids, fields = fields)
@@ -1,30 +1,22 @@
 import logging
 import time
 import pdb
 import copy
 import threading
-import logging
 from multiprocessing import Pool, Process
 import pytest
 from milvus import DataType
 from utils import *
+from constants import *

-dim = 128
-segment_row_count = 5000
-collection_id = "test_insert"
 ADD_TIMEOUT = 60
-tag = "1970_01_01"
-insert_interval_time = 1.5
-nb = 6000
-field_name = default_float_vec_field_name
-entity = gen_entities(1)
-raw_vector, binary_entity = gen_binary_entities(1)
-entities = gen_entities(nb)
-raw_vectors, binary_entities = gen_binary_entities(nb)
 default_fields = gen_default_fields()
+uid = "test_insert"
 field_name = default_float_vec_field_name
+binary_field_name = default_binary_vec_field_name
 default_single_query = {
     "bool": {
         "must": [
-            {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim),"metric_type":"L2",
+            {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type": "L2",
                                      "params": {"nprobe": 10}}}}
         ]
     }
@@ -89,9 +81,9 @@ class TestInsertBase:
         method: insert entity into a random named collection
         expected: error raised
         '''
-        collection_name = gen_unique_str(collection_id)
+        collection_name = gen_unique_str(uid)
         with pytest.raises(Exception) as e:
-            connect.insert(collection_name, entities)
+            connect.insert(collection_name, default_entities)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_drop_collection(self, connect, collection):

@@ -100,7 +92,7 @@ class TestInsertBase:
         method: insert vector and delete collection
         expected: no error raised
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         assert len(ids) == 1
         connect.drop_collection(collection)

@@ -111,7 +103,7 @@ class TestInsertBase:
         method: insert vector, sleep, and delete collection
         expected: no error raised
         '''
-        ids = connect.insert(collection, entity)
+        ids = connect.insert(collection, default_entity)
         assert len(ids) == 1
         connect.flush([collection])
         connect.drop_collection(collection)
@@ -123,10 +115,15 @@ class TestInsertBase:
         method: insert vector and build index
         expected: no error raised
         '''
-        ids = connect.insert(collection, entities)
-        assert len(ids) == nb
+        ids = connect.insert(collection, default_entities)
+        assert len(ids) == default_nb
         connect.flush([collection])
         connect.create_index(collection, field_name, get_simple_index)
+        info = connect.get_collection_info(collection)
+        fields = info["fields"]
+        for field in fields:
+            if field["name"] == field_name:
+                assert field["indexes"][0] == get_simple_index

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_after_create_index(self, connect, collection, get_simple_index):

@@ -136,8 +133,13 @@ class TestInsertBase:
         expected: no error raised
         '''
         connect.create_index(collection, field_name, get_simple_index)
-        ids = connect.insert(collection, entities)
-        assert len(ids) == nb
+        ids = connect.insert(collection, default_entities)
+        assert len(ids) == default_nb
+        info = connect.get_collection_info(collection)
+        fields = info["fields"]
+        for field in fields:
+            if field["name"] == field_name:
+                assert field["indexes"][0] == get_simple_index

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_search(self, connect, collection):
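Both hunks above add the same read-back verification: after create_index, get_collection_info reports the index under the target field's "indexes" list. Isolated as a helper (response shape as shown in the added assertions):

    def assert_index_applied(connect, collection, field_name, expected_index):
        # get_collection_info returns {"fields": [{"name": ..., "indexes": [...]}, ...]}
        info = connect.get_collection_info(collection)
        for field in info["fields"]:
            if field["name"] == field_name:
                assert field["indexes"][0] == expected_index
                return
        raise AssertionError("field %r not found in collection info" % field_name)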
@@ -146,17 +148,27 @@ class TestInsertBase:
         method: insert vector, sleep, and search collection
         expected: no error raised
         '''
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
         res = connect.search(collection, default_single_query)
         logging.getLogger().debug(res)
         assert res

+    def test_insert_segment_row_count(self, connect, collection):
+        nb = default_segment_row_limit + 1
+        res_ids = connect.insert(collection, gen_entities(nb))
+        connect.flush([collection])
+        assert len(res_ids) == nb
+        stats = connect.get_collection_stats(collection)
+        assert len(stats['partitions'][0]['segments']) == 2
+        for segment in stats['partitions'][0]['segments']:
+            assert segment['row_count'] in [default_segment_row_limit, 1]

     @pytest.fixture(
         scope="function",
         params=[
             1,
-            6000
+            2000
         ],
     )
     def insert_count(self, request):
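The new test_insert_segment_row_count above encodes the write-path invariant being exercised: rows are packed into segments of at most segment_row_limit each, so inserting default_segment_row_limit + 1 rows and flushing must produce exactly two segments, one full and one holding the single overflow row. The arithmetic behind the assertion:

    def expected_segment_rows(total_rows, segment_row_limit):
        # Rows fill whole segments of `segment_row_limit`, plus one partial remainder.
        full, rest = divmod(total_rows, segment_row_limit)
        return [segment_row_limit] * full + ([rest] if rest else [])

    # e.g. with a limit of 5000 (an assumed fixture value):
    assert expected_segment_rows(5001, 5000) == [5000, 1]  # two segments
    assert expected_segment_rows(5000, 5000) == [5000]     # exactly one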
@@ -207,7 +219,7 @@ class TestInsertBase:
         collection_name = gen_unique_str("test_collection")
         fields = {
             "fields": [filter_field, vector_field],
-            "segment_row_limit": segment_row_count,
+            "segment_row_limit": default_segment_row_limit,
             "auto_id": True
         }
         connect.create_collection(collection_name, fields)

@@ -228,10 +240,10 @@ class TestInsertBase:
         method: test insert vectors twice, use customize ids first, and then use no ids
         expected: error raised
         '''
-        ids = [i for i in range(nb)]
-        res_ids = connect.insert(id_collection, entities, ids)
+        ids = [i for i in range(default_nb)]
+        res_ids = connect.insert(id_collection, default_entities, ids)
         with pytest.raises(Exception) as e:
-            res_ids_new = connect.insert(id_collection, entities)
+            res_ids_new = connect.insert(id_collection, default_entities)

     # TODO: assert exception && enable
     @pytest.mark.level(2)

@@ -243,7 +255,7 @@ class TestInsertBase:
         expected: error raised
         '''
         with pytest.raises(Exception) as e:
-            res_ids = connect.insert(id_collection, entities)
+            res_ids = connect.insert(id_collection, default_entities)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_length_not_match_batch(self, connect, id_collection):

@@ -252,10 +264,10 @@ class TestInsertBase:
         method: create collection and insert vectors in it
         expected: raise an exception
         '''
-        ids = [i for i in range(1, nb)]
+        ids = [i for i in range(1, default_nb)]
         logging.getLogger().info(len(ids))
         with pytest.raises(Exception) as e:
-            res_ids = connect.insert(id_collection, entities, ids)
+            res_ids = connect.insert(id_collection, default_entities, ids)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_length_not_match_single(self, connect, collection):
@@ -264,10 +276,10 @@ class TestInsertBase:
         method: create collection and insert vectors in it
         expected: raise an exception
         '''
-        ids = [i for i in range(1, nb)]
+        ids = [i for i in range(1, default_nb)]
         logging.getLogger().info(len(ids))
         with pytest.raises(Exception) as e:
-            res_ids = connect.insert(collection, entity, ids)
+            res_ids = connect.insert(collection, default_entity, ids)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field):

@@ -282,10 +294,10 @@ class TestInsertBase:
         collection_name = gen_unique_str("test_collection")
         fields = {
             "fields": [filter_field, vector_field],
-            "segment_row_limit": segment_row_count
+            "segment_row_limit": default_segment_row_limit
         }
         connect.create_collection(collection_name, fields)
-        entities = gen_entities_by_fields(fields["fields"], nb, dim)
+        entities = gen_entities_by_fields(fields["fields"], nb, default_dim)
         res_ids = connect.insert(collection_name, entities)
         connect.flush([collection_name])
         res_count = connect.count_entities(collection_name)

@@ -298,9 +310,10 @@ class TestInsertBase:
         method: create collection and insert entities in it, with the partition_tag param
         expected: the collection row count equals to nq
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
-        assert len(ids) == nb
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+        assert len(ids) == default_nb
+        assert connect.has_partition(collection, default_tag)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_tag_with_ids(self, connect, id_collection):

@@ -309,9 +322,9 @@ class TestInsertBase:
         method: create collection and insert entities in it, with the partition_tag param
         expected: the collection row count equals to nq
         '''
-        connect.create_partition(id_collection, tag)
-        ids = [i for i in range(nb)]
-        res_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
+        connect.create_partition(id_collection, default_tag)
+        ids = [i for i in range(default_nb)]
+        res_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
         assert res_ids == ids

     @pytest.mark.timeout(ADD_TIMEOUT)
@@ -321,12 +334,12 @@ class TestInsertBase:
         method: create partition and insert into collection without tag params
         expected: the collection row count equals to nb
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities)
         connect.flush([collection])
-        assert len(ids) == nb
+        assert len(ids) == default_nb
         res_count = connect.count_entities(collection)
-        assert res_count == nb
+        assert res_count == default_nb

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_tag_not_existed(self, connect, collection):

@@ -337,7 +350,7 @@ class TestInsertBase:
         '''
         tag = gen_unique_str()
         with pytest.raises(Exception) as e:
-            ids = connect.insert(collection, entities, partition_tag=tag)
+            ids = connect.insert(collection, default_entities, partition_tag=tag)

     @pytest.mark.timeout(ADD_TIMEOUT)
     def test_insert_tag_existed(self, connect, collection):
@@ -346,12 +359,12 @@ class TestInsertBase:
         method: create collection and insert entities in it repeatedly, with the partition_tag param
         expected: the collection row count equals to nq
         '''
-        connect.create_partition(collection, tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
-        ids = connect.insert(collection, entities, partition_tag=tag)
+        connect.create_partition(collection, default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+        ids = connect.insert(collection, default_entities, partition_tag=default_tag)
         connect.flush([collection])
         res_count = connect.count_entities(collection)
-        assert res_count == 2 * nb
+        assert res_count == 2 * default_nb

     @pytest.mark.level(2)
     def test_insert_without_connect(self, dis_connect, collection):
@@ -361,7 +374,7 @@ class TestInsertBase:
         expected: raise exception
         '''
         with pytest.raises(Exception) as e:
-            ids = dis_connect.insert(collection, entities)
+            ids = dis_connect.insert(collection, default_entities)

     def test_insert_collection_not_existed(self, connect):
         '''

@@ -370,7 +383,7 @@ class TestInsertBase:
         expected: error raised
         '''
         with pytest.raises(Exception) as e:
-            ids = connect.insert(gen_unique_str("not_exist_collection"), entities)
+            ids = connect.insert(gen_unique_str("not_exist_collection"), default_entities)

     def test_insert_dim_not_matched(self, connect, collection):
         '''

@@ -378,8 +391,8 @@ class TestInsertBase:
         method: the entities dimension is half of the collection dimension, check the status
         expected: error raised
         '''
-        vectors = gen_vectors(nb, int(dim) // 2)
-        insert_entities = copy.deepcopy(entities)
+        vectors = gen_vectors(default_nb, int(default_dim) // 2)
+        insert_entities = copy.deepcopy(default_entities)
         insert_entities[-1]["values"] = vectors
         with pytest.raises(Exception) as e:
             ids = connect.insert(collection, insert_entities)

@@ -390,7 +403,7 @@ class TestInsertBase:
         method: update entity field name
         expected: error raised
         '''
-        tmp_entity = update_field_name(copy.deepcopy(entity), "int64", "int64new")
+        tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", "int64new")
         with pytest.raises(Exception):
             connect.insert(collection, tmp_entity)
@ -401,7 +414,7 @@ class TestInsertBase:
|
|||
method: update entity field type
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = update_field_type(copy.deepcopy(entity), "int64", DataType.FLOAT)
|
||||
tmp_entity = update_field_type(copy.deepcopy(default_entity), "int64", DataType.FLOAT)
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
|
@ -412,7 +425,7 @@ class TestInsertBase:
|
|||
method: update entity field value
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = update_field_value(copy.deepcopy(entity), DataType.FLOAT, 's')
|
||||
tmp_entity = update_field_value(copy.deepcopy(default_entity), DataType.FLOAT, 's')
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
|
@ -422,7 +435,7 @@ class TestInsertBase:
|
|||
method: add entity field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = add_field(copy.deepcopy(entity))
|
||||
tmp_entity = add_field(copy.deepcopy(default_entity))
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
|
@ -432,7 +445,7 @@ class TestInsertBase:
|
|||
method: add entity vector field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = add_vector_field(nb, dim)
|
||||
tmp_entity = add_vector_field(default_nb, default_dim)
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
|
@ -442,7 +455,7 @@ class TestInsertBase:
|
|||
method: remove entity field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = remove_field(copy.deepcopy(entity))
|
||||
tmp_entity = remove_field(copy.deepcopy(default_entity))
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
|
@ -452,7 +465,7 @@ class TestInsertBase:
|
|||
method: remove entity vector field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = remove_vector_field(copy.deepcopy(entity))
|
||||
tmp_entity = remove_vector_field(copy.deepcopy(default_entity))
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
||||
|
@ -462,7 +475,7 @@ class TestInsertBase:
|
|||
method: remove entity vector field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = copy.deepcopy(entity)
|
||||
tmp_entity = copy.deepcopy(default_entity)
|
||||
del tmp_entity[-1]["values"]
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
@ -473,7 +486,7 @@ class TestInsertBase:
|
|||
method: remove entity vector field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = copy.deepcopy(entity)
|
||||
tmp_entity = copy.deepcopy(default_entity)
|
||||
del tmp_entity[-1]["type"]
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
|
||||
|
@ -484,8 +497,8 @@ class TestInsertBase:
|
|||
method: remove entity vector field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_entity = copy.deepcopy(entity)
|
||||
del tmp_entity[-1]["field"]
|
||||
tmp_entity = copy.deepcopy(default_entity)
|
||||
del tmp_entity[-1]["name"]
|
||||
with pytest.raises(Exception):
|
||||
connect.insert(collection, tmp_entity)
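
These tests keep mutating tmp_entity[-1]["values"] / ["type"] / ["name"], which only makes sense once you know the entity layout gen_entities produces: a list of per-field dicts with the vector field last. A minimal sketch of that assumed layout (the real helper lives in utils and may differ in detail):

import random
from milvus import DataType

def sketch_entities(nb, dim=128):
    # Assumed layout: one dict per field, vector field at index -1, which is
    # why the tests above reach for tmp_entity[-1] when they corrupt a field.
    return [
        {"name": "int64", "type": DataType.INT64, "values": list(range(nb))},
        {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"name": "float_vector", "type": DataType.FLOAT_VECTOR,
         "values": [[random.random() for _ in range(dim)] for _ in range(nb)]},
    ]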
@@ -506,7 +519,7 @@ class TestInsertBase:

def insert(thread_i):
logging.getLogger().info("In thread-%d" % thread_i)
res_ids = milvus.insert(collection, entities)
res_ids = milvus.insert(collection, default_entities)
milvus.flush([collection])

for i in range(thread_num):

@@ -516,7 +529,7 @@ class TestInsertBase:
for th in threads:
th.join()
res_count = milvus.count_entities(collection)
assert res_count == thread_num * nb
assert res_count == thread_num * default_nb
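
The two hunks above show only the worker and the join loop; the elided middle is presumably a plain threading fan-out. A self-contained sketch of that pattern (thread_num and the worker body are placeholders, not the test's actual values):

import threading

thread_num = 4  # placeholder; the diff elides the real value

def insert(thread_i):
    # stand-in for milvus.insert(collection, default_entities) + flush
    print("In thread-%d" % thread_i)

threads = [threading.Thread(target=insert, args=(i,)) for i in range(thread_num)]
for th in threads:
    th.start()
for th in threads:
    th.join()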
# TODO: unable to set config
@pytest.mark.level(2)

@@ -526,10 +539,102 @@ class TestInsertBase:
method: disable autoflush and insert, get entity
expected: the count is equal to 0
'''
delete_nums = 500
disable_flush(connect)
ids = connect.insert(collection, entities)
res = connect.get_entity_by_id(collection, ids[:500])
assert len(res) == 0
ids = connect.insert(collection, default_entities)
res = connect.get_entity_by_id(collection, ids[:delete_nums])
assert len(res) == delete_nums
assert res[0] is None
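
The assertions above changed meaning with this diff: with auto-flush disabled, get_entity_by_id now returns one None placeholder per requested id rather than an empty list. A hypothetical companion check in the same fixture style, making the before/after-flush contrast explicit (a sketch under that assumption, not part of the commit):

def test_insert_flush_visibility_sketch(self, connect, collection):
    # Assumption: ids are allocated at insert time, but lookups return None
    # placeholders until an explicit flush makes the segment visible.
    disable_flush(connect)
    ids = connect.insert(collection, default_entities)
    assert all(r is None for r in connect.get_entity_by_id(collection, ids[:10]))
    connect.flush([collection])
    assert all(r is not None for r in connect.get_entity_by_id(collection, ids[:10]))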
class TestInsertBinary:
@pytest.fixture(
scope="function",
params=gen_binary_index()
)
def get_binary_index(self, request):
request.param["metric_type"] = "JACCARD"
return request.param
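
The fixture above iterates over whatever gen_binary_index() yields and then pins the metric to JACCARD. The exact list lives in utils; a plausible shape, for orientation only:

def gen_binary_index_sketch():
    # Hypothetical stand-in for utils.gen_binary_index(): one param dict per
    # binary-capable index type (BIN_FLAT / BIN_IVF_FLAT in Milvus 0.x).
    return [
        {"index_type": "BIN_FLAT", "params": {"nlist": 128}},
        {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}},
    ]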
def test_insert_binary_entities(self, connect, binary_collection):
'''
target: test insert entities in binary collection
method: create collection and insert binary entities in it
expected: the collection row count equals nb
'''
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush()
assert connect.count_entities(binary_collection) == default_nb

def test_insert_binary_tag(self, connect, binary_collection):
'''
target: test insert entities and create partition tag
method: create collection and insert binary entities in it, with the partition_tag param
expected: the collection row count equals nb
'''
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(binary_collection, default_tag)

def test_insert_binary_multi_times(self, connect, binary_collection):
'''
target: test insert entities multiple times with a final flush
method: create collection, insert a binary entity many times, then flush once
expected: the collection row count equals nb
'''
for i in range(default_nb):
ids = connect.insert(binary_collection, default_binary_entity)
assert len(ids) == 1
connect.flush([binary_collection])
assert connect.count_entities(binary_collection) == default_nb

def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
'''
target: test insert binary entities after building an index
method: build index and insert entities
expected: no error raised
'''
connect.create_index(binary_collection, binary_field_name, get_binary_index)
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush([binary_collection])
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["name"] == binary_field_name:
assert field["indexes"][0] == get_binary_index

@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index):
'''
target: test building an index after inserting vectors
method: insert vectors and build index
expected: no error raised
'''
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_binary_index)
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["name"] == binary_field_name:
assert field["indexes"][0] == get_binary_index

def test_insert_binary_search(self, connect, binary_collection):
'''
target: test searching vectors a while after inserting them
method: insert vectors, sleep, and search the collection
expected: no error raised
'''
ids = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD")
res = connect.search(binary_collection, query)
logging.getLogger().debug(res)
assert res
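
The JACCARD metric used above treats each binary vector as a set of bits: the distance is 1 minus intersection-over-union of the set bits. A self-contained reference implementation (utils ships its own jaccard helper, which may differ in detail):

import numpy as np

def jaccard_distance(x, y):
    # x, y: packed uint8 arrays, the on-wire format of a binary collection
    x = np.unpackbits(np.asarray(x, dtype=np.uint8))
    y = np.unpackbits(np.asarray(y, dtype=np.uint8))
    inter = np.bitwise_and(x, y).sum()
    union = np.bitwise_or(x, y).sum()
    return 1.0 - inter / union

print(jaccard_distance([0b1100], [0b1010]))  # 1 - 1/3, about 0.667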
class TestInsertAsync:

@@ -628,7 +733,7 @@ class TestInsertAsync:
expected: length of ids is equal to the length of vectors
'''
collection_new = gen_unique_str()
future = connect.insert(collection_new, entities, _async=True)
future = connect.insert(collection_new, default_entities, _async=True)
with pytest.raises(Exception) as e:
result = future.result()

@@ -671,14 +776,14 @@ class TestInsertMultiCollections:
collection_num = 10
collection_list = []
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
collection_list.append(collection_name)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection_name, entities)
ids = connect.insert(collection_name, default_entities)
connect.flush([collection_name])
assert len(ids) == nb
assert len(ids) == default_nb
count = connect.count_entities(collection_name)
assert count == nb
assert count == default_nb

@pytest.mark.timeout(ADD_TIMEOUT)
def test_drop_collection_insert_vector_another(self, connect, collection):

@@ -687,10 +792,10 @@ class TestInsertMultiCollections:
method: delete collection_2 and insert vector to collection_1
expected: row count equals the length of entities inserted
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.drop_collection(collection)
ids = connect.insert(collection_name, entity)
ids = connect.insert(collection_name, default_entity)
connect.flush([collection_name])
assert len(ids) == 1

@@ -701,10 +806,10 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
connect.drop_collection(collection_name)

@pytest.mark.timeout(ADD_TIMEOUT)

@@ -714,9 +819,9 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
connect.create_index(collection, field_name, get_simple_index)
count = connect.count_entities(collection_name)
assert count == 0

@@ -728,9 +833,9 @@ class TestInsertMultiCollections:
method: build index and insert vector
expected: status ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
count = connect.count_entities(collection)

@@ -743,11 +848,11 @@ class TestInsertMultiCollections:
method: search collection and insert vector
expected: status ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
res = connect.search(collection, default_single_query)
logging.getLogger().debug(res)
ids = connect.insert(collection_name, entity)
ids = connect.insert(collection_name, default_entity)
connect.flush()
count = connect.count_entities(collection_name)
assert count == 1

@@ -759,9 +864,9 @@ class TestInsertMultiCollections:
method: search collection and insert vector
expected: status ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
result = connect.search(collection_name, default_single_query)

@pytest.mark.timeout(ADD_TIMEOUT)

@@ -771,9 +876,9 @@ class TestInsertMultiCollections:
method: search collection, sleep, and insert vector
expected: status ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection, entity)
ids = connect.insert(collection, default_entity)
connect.flush([collection])
result = connect.search(collection_name, default_single_query)
@@ -839,44 +944,44 @@ class TestInsertInvalid(object):
expected: raise an exception
'''
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
ids = [entity_id for _ in range(default_nb)]
with pytest.raises(Exception):
connect.insert(id_collection, entities, ids)
connect.insert(id_collection, default_entities, ids)

def test_insert_with_invalid_collection_name(self, connect, get_collection_name):
collection_name = get_collection_name
with pytest.raises(Exception):
connect.insert(collection_name, entity)
connect.insert(collection_name, default_entity)

def test_insert_with_invalid_tag_name(self, connect, collection, get_tag_name):
tag_name = get_tag_name
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
if tag_name is not None:
with pytest.raises(Exception):
connect.insert(collection, entity, partition_tag=tag_name)
connect.insert(collection, default_entity, partition_tag=tag_name)
else:
connect.insert(collection, entity, partition_tag=tag_name)
connect.insert(collection, default_entity, partition_tag=tag_name)

def test_insert_with_invalid_field_name(self, connect, collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(entity), "int64", get_field_name)
tmp_entity = update_field_name(copy.deepcopy(default_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)

def test_insert_with_invalid_field_type(self, connect, collection, get_field_type):
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(entity), 'float', field_type)
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'float', field_type)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)

def test_insert_with_invalid_field_value(self, connect, collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_entity), 'int64', field_value)
with pytest.raises(Exception):
connect.insert(collection, tmp_entity)

def test_insert_with_invalid_field_vector_value(self, connect, collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(entity)
tmp_entity = copy.deepcopy(default_entity)
src_vector = tmp_entity[-1]["values"]
src_vector[0][1] = get_field_vectors_value
with pytest.raises(Exception):

@@ -939,21 +1044,19 @@ class TestInsertInvalidBinary(object):

@pytest.mark.level(2)
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(binary_entity), "int64", get_field_name)
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)

@pytest.mark.level(2)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(binary_entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)

@pytest.mark.level(2)
def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(binary_entity)
tmp_entity = copy.deepcopy(default_binary_entity)
src_vector = tmp_entity[-1]["values"]
src_vector[0][1] = get_field_vectors_value
with pytest.raises(Exception):

@@ -967,34 +1070,20 @@ class TestInsertInvalidBinary(object):
expected: raise an exception
'''
entity_id = get_entity_id
ids = [entity_id for _ in range(nb)]
ids = [entity_id for _ in range(default_nb)]
with pytest.raises(Exception):
connect.insert(binary_id_collection, binary_entities, ids)

@pytest.mark.level(2)
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
connect.insert(binary_id_collection, default_binary_entities, ids)

@pytest.mark.level(2)
def test_insert_with_invalid_field_type(self, connect, binary_collection, get_field_type):
field_type = get_field_type
tmp_entity = update_field_type(copy.deepcopy(binary_entity), 'int64', field_type)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)

@pytest.mark.level(2)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(binary_entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_type)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)

@pytest.mark.level(2)
def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(binary_entities)
tmp_entity = copy.deepcopy(default_binary_entities)
src_vector = tmp_entity[-1]["values"]
src_vector[1] = get_field_vectors_value
with pytest.raises(Exception):

@@ -6,20 +6,9 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *

dim = 128
segment_row_count = 100000
nb = 6000
tag = "1970_01_01"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
collection_id = "list_id_in_segment"
entity = gen_entities(1)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()

uid = "list_id_in_segment"

def get_segment_id(connect, collection, nb=1, vec_type='float', index_params=None):
if vec_type != "float":

@@ -30,9 +19,9 @@ def get_segment_id(connect, collection, nb=1, vec_type='float', index_params=None):
connect.flush([collection])
if index_params:
if vec_type == 'float':
connect.create_index(collection, field_name, index_params)
connect.create_index(collection, default_float_vec_field_name, index_params)
else:
connect.create_index(collection, binary_field_name, index_params)
connect.create_index(collection, default_binary_vec_field_name, index_params)
stats = connect.get_collection_stats(collection)
return ids, stats["partitions"][0]["segments"][0]["id"]
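
For orientation, the helper above is what every test in this file leans on: insert nb entities, flush, optionally build an index, and read the first segment's id out of the collection stats. A typical call site (a sketch; connect and collection are the usual pytest fixtures):

# Returns the inserted ids plus the id of the first segment, ready to be
# handed to list_id_in_segment:
ids, seg_id = get_segment_id(connect, collection, nb=default_nb)
vector_ids = connect.list_id_in_segment(collection, seg_id)
assert len(vector_ids) == default_nb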
@@ -61,7 +50,7 @@ class TestListIdInSegmentBase:
method: call list_id_in_segment with a random collection_name, which is not in db
expected: status not ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
ids, segment_id = get_segment_id(connect, collection)
with pytest.raises(Exception) as e:
vector_ids = connect.list_id_in_segment(collection_name, segment_id)

@@ -73,6 +62,7 @@ class TestListIdInSegmentBase:
def get_collection_name(self, request):
yield request.param

@pytest.mark.level(2)
def test_list_id_in_segment_collection_name_invalid(self, connect, collection, get_collection_name):
'''
target: get vector ids where collection name is invalid

@@ -102,7 +92,7 @@ class TestListIdInSegmentBase:
expected: status not ok
'''
ids, seg_id = get_segment_id(connect, collection)
# segment = gen_unique_str(collection_id)
# segment = gen_unique_str(uid)
with pytest.raises(Exception) as e:
vector_ids = connect.list_id_in_segment(collection, seg_id + 10000)

@@ -129,11 +119,11 @@ class TestListIdInSegmentBase:
'''
nb = 10
entities = gen_entities(nb)
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, entities, partition_tag=default_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["tag"] == default_tag
vector_ids = connect.list_id_in_segment(collection, stats["partitions"][1]["segments"][0]["id"])
# vector_ids should match ids
assert len(vector_ids) == nb

@@ -157,7 +147,7 @@ class TestListIdInSegmentBase:
method: call list_id_in_segment and check if the segment contains vectors
expected: status ok
'''
ids, seg_id = get_segment_id(connect, collection, nb=nb, index_params=get_simple_index)
ids, seg_id = get_segment_id(connect, collection, nb=default_nb, index_params=get_simple_index)
try:
connect.list_id_in_segment(collection, seg_id)
except Exception as e:

@@ -171,11 +161,11 @@ class TestListIdInSegmentBase:
method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
expected: status ok
'''
connect.create_partition(collection, tag)
ids = connect.insert(collection, entities, partition_tag=tag)
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["tag"] == default_tag
try:
connect.list_id_in_segment(collection, stats["partitions"][1]["segments"][0]["id"])
except Exception as e:

@@ -183,7 +173,6 @@ class TestListIdInSegmentBase:
# vector_ids should match ids
# TODO

@pytest.mark.level(2)
def test_list_id_in_segment_after_delete_vectors(self, connect, collection):
'''
target: get vector ids after vectors are deleted

@@ -200,6 +189,24 @@ class TestListIdInSegmentBase:
assert len(vector_ids) == 1
assert vector_ids[0] == ids[1]

@pytest.mark.level(2)
def test_list_id_in_segment_after_delete_vectors(self, connect, collection):
'''
target: get vector ids after vectors are deleted
method: add vectors and delete a few, call list_id_in_segment
expected: vector_ids decreased after vectors deleted
'''
nb = 60
delete_length = 10
ids, seg_id = get_segment_id(connect, collection, nb=nb)
delete_ids = ids[:delete_length]
status = connect.delete_entity_by_id(collection, delete_ids)
connect.flush([collection])
stats = connect.get_collection_stats(collection)
vector_ids = connect.list_id_in_segment(collection, stats["partitions"][0]["segments"][0]["id"])
assert len(vector_ids) == nb - delete_length
assert vector_ids[0] == ids[delete_length]
@pytest.mark.level(2)
def test_list_id_in_segment_with_index_ip(self, connect, collection, get_simple_index):
'''

@@ -208,11 +215,11 @@ class TestListIdInSegmentBase:
expected: ids returned in ids inserted
'''
get_simple_index["metric_type"] = "IP"
ids, seg_id = get_segment_id(connect, collection, nb=nb, index_params=get_simple_index)
ids, seg_id = get_segment_id(connect, collection, nb=default_nb, index_params=get_simple_index)
vector_ids = connect.list_id_in_segment(collection, seg_id)
# TODO:
segment_row_count = connect.get_collection_info(collection)["segment_row_limit"]
assert vector_ids[0:segment_row_count] == ids[0:segment_row_count]
segment_row_limit = connect.get_collection_info(collection)["segment_row_limit"]
assert vector_ids[0:segment_row_limit] == ids[0:segment_row_limit]

class TestListIdInSegmentBinary:
"""

@@ -245,10 +252,10 @@ class TestListIdInSegmentBinary:
method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
expected: status ok
'''
connect.create_partition(binary_collection, tag)
connect.create_partition(binary_collection, default_tag)
nb = 10
vectors, entities = gen_binary_entities(nb)
ids = connect.insert(binary_collection, entities, partition_tag=tag)
ids = connect.insert(binary_collection, entities, partition_tag=default_tag)
connect.flush([binary_collection])
stats = connect.get_collection_stats(binary_collection)
vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][1]["segments"][0]["id"])

@@ -275,7 +282,7 @@ class TestListIdInSegmentBinary:
method: call list_id_in_segment and check if the segment contains vectors
expected: status ok
'''
ids, seg_id = get_segment_id(connect, binary_collection, nb=nb, index_params=get_jaccard_index, vec_type='binary')
ids, seg_id = get_segment_id(connect, binary_collection, nb=default_nb, index_params=get_jaccard_index, vec_type='binary')
vector_ids = connect.list_id_in_segment(binary_collection, seg_id)
# TODO:

@@ -285,11 +292,11 @@ class TestListIdInSegmentBinary:
method: create partition, add vectors to it and call list_id_in_segment, check if the segment contains vectors
expected: status ok
'''
connect.create_partition(binary_collection, tag)
ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
connect.flush([binary_collection])
stats = connect.get_collection_stats(binary_collection)
assert stats["partitions"][1]["tag"] == tag
assert stats["partitions"][1]["tag"] == default_tag
vector_ids = connect.list_id_in_segment(binary_collection, stats["partitions"][1]["segments"][0]["id"])
# vector_ids should match ids
# TODO

@@ -9,36 +9,27 @@ import numpy as np

from milvus import DataType
from utils import *
from constants import *

dim = 128
segment_row_count = 5000
top_k_limit = 16384
collection_id = "search"
tag = "1970_01_01"
insert_interval_time = 1.5
nb = 6000
top_k = 10
uid = "test_search"
nq = 1
nprobe = 1
epsilon = 0.001
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
default_fields = gen_default_fields()
search_param = {"nprobe": 1}

entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, nq)
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
entities = gen_entities(default_nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(default_nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq)
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k, nq)
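
Everything below searches with query objects built by gen_query_vectors. The helper lives in utils, but the DSL it presumably emits is the 0.11-style boolean query: a "must" list holding one vector clause keyed by the field name. A hypothetical sketch of a single query (field_name is the module constant defined above):

top_k, nq, dim = 10, 1, 128
vecs = [[0.0] * dim for _ in range(nq)]
query = {
    "bool": {
        "must": [{
            "vector": {
                field_name: {
                    "topk": top_k,
                    "query": vecs,
                    "metric_type": "L2",
                    "params": {"nprobe": 1},
                }
            }
        }]
    }
}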
def init_data(connect, collection, nb=6000, partition_tags=None, auto_id=True):
def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True):
'''
Generate entities and add them to the collection
'''
global entities
if nb == 6000:
if nb == 1200:
insert_entities = entities
else:
insert_entities = gen_entities(nb, is_normal=True)

@@ -56,14 +47,14 @@ def init_data(connect, collection, nb=6000, partition_tags=None, auto_id=True):
return insert_entities, ids

def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=None):
'''
Generate entities and add them to the collection
'''
ids = []
global binary_entities
global raw_vectors
if nb == 6000:
if nb == 1200:
insert_entities = binary_entities
insert_raw_vectors = raw_vectors
else:
@@ -141,7 +132,7 @@ class TestSearchBase:

@pytest.fixture(
scope="function",
params=[1, 10, 16385]
params=[1, 10]
)
def get_top_k(self, request):
yield request.param

@@ -155,7 +146,7 @@ class TestSearchBase:

def test_search_flat(self, connect, collection, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, change top-k value
target: test basic search function, all the search params are correct, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''

@@ -163,7 +154,26 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k <= top_k_limit:
if top_k <= max_top_k:
res = connect.search(collection, query)
assert len(res[0]) == top_k
assert res[0]._distances[0] <= epsilon
assert check_id_result(res[0], ids[0])
else:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)

def test_search_flat_top_k(self, connect, collection, get_nq):
'''
target: test basic search function, all the search params are correct, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''
top_k = 16385
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k <= max_top_k:
res = connect.search(collection, query)
assert len(res[0]) == top_k
assert res[0]._distances[0] <= epsilon
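
check_id_result above is another utils helper; a minimal sketch of what it is assumed to verify, namely that the expected id shows up among one query's top-k hits:

def check_id_result_sketch(result, id):
    # Hypothetical equivalent of utils.check_id_result: membership test over
    # the ids of a single query's top-k hit list.
    return id in [hit.id for hit in result]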
@@ -174,7 +184,7 @@ class TestSearchBase:

def test_search_field(self, connect, collection, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, change top-k value
target: test basic search function, all the search params are correct, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''

@@ -182,7 +192,7 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k <= top_k_limit:
if top_k <= max_top_k:
res = connect.search(collection, query, fields=["float_vector"])
assert len(res[0]) == top_k
assert res[0]._distances[0] <= epsilon

@@ -198,7 +208,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''

@@ -206,13 +216,13 @@ class TestSearchBase:
nq = get_nq

index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:

@@ -233,16 +243,16 @@ class TestSearchBase:
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=search_metric_type,
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, metric_type=search_metric_type,
search_params=search_param)
res = connect.search(collection, query)
assert len(res) == nq
assert len(res[0]) == top_k
assert len(res[0]) == default_top_k

@pytest.mark.level(2)
def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: add vectors into collection, search with the given vectors, check the result
expected: the length of the result is top_k, search collection with partition tag return empty
'''

@@ -250,14 +260,14 @@ class TestSearchBase:
nq = get_nq

index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
entities, ids = init_data(connect, collection)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:

@@ -266,13 +276,13 @@ class TestSearchBase:
assert len(res[0]) >= top_k
assert res[0]._distances[0] < epsilon
assert check_id_result(res[0], ids[0])
res = connect.search(collection, query, partition_tags=[tag])
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res) == nq
@pytest.mark.level(2)
def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''

@@ -280,15 +290,15 @@ class TestSearchBase:
nq = get_nq

index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
connect.create_partition(collection, default_tag)
entities, ids = init_data(connect, collection, partition_tags=default_tag)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
for tags in [[tag], [tag, "new_tag"]]:
if top_k > top_k_limit:
for tags in [[default_tag], [default_tag, "new_tag"]]:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query, partition_tags=tags)
else:

@@ -301,7 +311,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search with the given vectors and tag (tag name not existed in collection), check the result
expected: error raised
'''

@@ -309,7 +319,7 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query, partition_tags=["new_tag"])
else:

@@ -320,7 +330,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: the length of the result is top_k
'''

@@ -328,16 +338,16 @@ class TestSearchBase:
nq = 2
new_tag = "new_tag"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
entities, ids = init_data(connect, collection, partition_tags=default_tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:

@@ -353,7 +363,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: the length of the result is top_k
'''

@@ -362,7 +372,7 @@ class TestSearchBase:
tag = "tag"
new_tag = "new_tag"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, new_tag)

@@ -371,7 +381,7 @@ class TestSearchBase:
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:
@@ -389,7 +399,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, change top-k value
target: test basic search function, all the search params are correct, change top-k value
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''

@@ -397,7 +407,7 @@ class TestSearchBase:
nq = get_nq
entities, ids = init_data(connect, collection)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
if top_k <= top_k_limit:
if top_k <= max_top_k:
res = connect.search(collection, query)
assert len(res[0]) == top_k
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])

@@ -409,7 +419,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search with the given vectors, check the result
expected: the length of the result is top_k
'''

@@ -417,14 +427,14 @@ class TestSearchBase:
nq = get_nq

index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
entities, ids = init_data(connect, collection)
get_simple_index["metric_type"] = "IP"
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:

@@ -437,7 +447,7 @@ class TestSearchBase:
@pytest.mark.level(2)
def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: add vectors into collection, search with the given vectors, check the result
expected: the length of the result is top_k, search collection with partition tag return empty
'''

@@ -445,16 +455,16 @@ class TestSearchBase:
nq = get_nq
metric_type = "IP"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
entities, ids = init_data(connect, collection)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:

@@ -463,13 +473,13 @@ class TestSearchBase:
assert len(res[0]) >= top_k
assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
assert check_id_result(res[0], ids[0])
res = connect.search(collection, query, partition_tags=[tag])
res = connect.search(collection, query, partition_tags=[default_tag])
assert len(res) == nq

@pytest.mark.level(2)
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
target: test basic search function, all the search params are correct, test all index params, and build
method: search collection with the given vectors and tags, check the result
expected: the length of the result is top_k
'''

@@ -478,17 +488,17 @@ class TestSearchBase:
metric_type = "IP"
new_tag = "new_tag"
index_type = get_simple_index["index_type"]
if index_type == "IVF_PQ":
if index_type in skip_pq():
pytest.skip("Skip PQ")
connect.create_partition(collection, tag)
connect.create_partition(collection, default_tag)
connect.create_partition(collection, new_tag)
entities, ids = init_data(connect, collection, partition_tags=tag)
entities, ids = init_data(connect, collection, partition_tags=default_tag)
new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
get_simple_index["metric_type"] = metric_type
connect.create_index(collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
if top_k > top_k_limit:
if top_k > max_top_k:
with pytest.raises(Exception) as e:
res = connect.search(collection, query)
else:

@@ -518,7 +528,7 @@ class TestSearchBase:
method: search with the random collection_name, which is not in db
expected: status not ok
'''
collection_name = gen_unique_str(collection_id)
collection_name = gen_unique_str(uid)
with pytest.raises(Exception) as e:
res = connect.search(collection_name, default_query)

@@ -531,8 +541,8 @@ class TestSearchBase:
nq = 2
search_param = {"nprobe": 1}
entities, ids = init_data(connect, collection, nb=nq)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param)
distance_0 = l2(vecs[0], inside_vecs[0])
distance_1 = l2(vecs[0], inside_vecs[1])
res = connect.search(collection, query)
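
The l2 and ip reference distances computed here and below come from utils; under the usual definitions they reduce to a couple of NumPy lines. A sketch (note Milvus reports squared L2, so the helper is assumed to match that convention):

import numpy as np

def l2(x, y):
    # Squared Euclidean distance; drop the square only if the utils version
    # returns the plain norm instead.
    d = np.asarray(x, dtype=np.float32) - np.asarray(y, dtype=np.float32)
    return float(np.inner(d, d))

def ip(x, y):
    # Inner-product similarity: larger means closer, hence the max-distance
    # bookkeeping in the IP tests below.
    return float(np.inner(np.asarray(x, dtype=np.float32),
                          np.asarray(y, dtype=np.float32)))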
@@ -549,11 +559,11 @@ class TestSearchBase:
entities, ids = init_data(connect, id_collection, auto_id=False)
connect.create_index(id_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
inside_vecs = entities[-1]["values"]
min_distance = 1.0
min_id = None
for i in range(nb):
for i in range(default_nb):
tmp_dis = l2(vecs[0], inside_vecs[i])
if min_distance > tmp_dis:
min_distance = tmp_dis

@@ -577,9 +587,9 @@ class TestSearchBase:
metirc_type = "IP"
search_param = {"nprobe": 1}
entities, ids = init_data(connect, collection, nb=nq)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metirc_type,
search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param)
distance_0 = ip(vecs[0], inside_vecs[0])
distance_1 = ip(vecs[0], inside_vecs[1])
res = connect.search(collection, query)

@@ -598,12 +608,12 @@ class TestSearchBase:
get_simple_index["metric_type"] = metirc_type
connect.create_index(id_collection, field_name, get_simple_index)
search_param = get_search_param(index_type)
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metirc_type,
search_params=search_param)
inside_vecs = entities[-1]["values"]
max_distance = 0
max_id = None
for i in range(nb):
for i in range(default_nb):
tmp_dis = ip(vecs[0], inside_vecs[i])
if max_distance < tmp_dis:
max_distance = tmp_dis

@@ -627,7 +637,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="JACCARD")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD")
res = connect.search(binary_collection, query)
assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon

@@ -643,7 +653,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="L2")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2")
with pytest.raises(Exception) as e:
res = connect.search(binary_collection, query)

@@ -659,7 +669,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = hamming(query_int_vectors[0], int_vectors[0])
distance_1 = hamming(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="HAMMING")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING")
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
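
The .astype(float) in the assertion above hints that hamming returns a NumPy integer; a self-contained sketch consistent with that, on the packed uint8 format of a binary collection:

import numpy as np

def hamming(x, y):
    # Count differing bits; np.sum yields a NumPy integer, which is exactly
    # why the assertion above casts the minimum with .astype(float).
    x = np.unpackbits(np.asarray(x, dtype=np.uint8))
    y = np.unpackbits(np.asarray(y, dtype=np.uint8))
    return np.sum(x != y)

print(hamming([0b1100], [0b1010]))  # 2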
@@ -675,7 +685,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = substructure(query_int_vectors[0], int_vectors[0])
distance_1 = substructure(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="SUBSTRUCTURE")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUBSTRUCTURE")
res = connect.search(binary_collection, query)
assert len(res[0]) == 0

@@ -708,7 +718,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="SUPERSTRUCTURE")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUPERSTRUCTURE")
res = connect.search(binary_collection, query)
assert len(res[0]) == 0

@@ -743,7 +753,7 @@ class TestSearchBase:
query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, metric_type="TANIMOTO")
query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO")
res = connect.search(binary_collection, query)
assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon
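
Jaccard and Tanimoto both start from the same intersection-over-union of set bits; in Milvus the Tanimoto distance is conventionally reported as the negative log2 of that ratio. A sketch under the assumption that the utils helper follows the server's convention:

import numpy as np

def tanimoto(x, y):
    # T = |x AND y| / |x OR y|, reported as a distance via -log2(T);
    # assumption: this mirrors Milvus's TANIMOTO metric, in the same packed
    # uint8 format as the jaccard sketch earlier.
    x = np.unpackbits(np.asarray(x, dtype=np.uint8)).astype(bool)
    y = np.unpackbits(np.asarray(y, dtype=np.uint8)).astype(bool)
    t = np.bitwise_and(x, y).sum() / np.bitwise_or(x, y).sum()
    return -np.log2(t)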
|
||||
|
||||
|
@ -759,7 +769,7 @@ class TestSearchBase:
|
|||
top_k = 10
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection = gen_unique_str(collection_id)
|
||||
collection = gen_unique_str(uid)
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
# create collection
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
|
@ -793,7 +803,7 @@ class TestSearchBase:
|
|||
top_k = 10
|
||||
threads_num = 4
|
||||
threads = []
|
||||
collection = gen_unique_str(collection_id)
|
||||
collection = gen_unique_str(uid)
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
# create collection
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
|
@ -825,10 +835,10 @@ class TestSearchBase:
|
|||
top_k = 10
|
||||
nq = 20
|
||||
for i in range(num):
|
||||
collection = gen_unique_str(collection_id + str(i))
|
||||
collection = gen_unique_str(uid + str(i))
|
||||
connect.create_collection(collection, default_fields)
|
||||
entities, ids = init_data(connect, collection)
|
||||
assert len(ids) == nb
|
||||
assert len(ids) == default_nb
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
|
||||
res = connect.search(collection, query)
|
||||
assert len(res) == nq
|
||||
|
@ -885,7 +895,7 @@ class TestSearchDSL(object):
         entities, ids = init_data(connect, collection)
         res = connect.search(collection, default_query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k

     def test_query_wrong_format(self, connect, collection):
         '''
@ -971,7 +981,7 @@ class TestSearchDSL(object):
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k
         # TODO:

     def test_query_term_values_parts_in(self, connect, collection):
@ -981,11 +991,11 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         expr = {"must": [gen_default_vector_expr(default_query),
-                         gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])]}
+                         gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k
         # TODO:

     # TODO:
@ -997,7 +1007,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         expr = {
-            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, nb)])]}
+            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, default_nb)])]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
@ -1029,7 +1039,7 @@ class TestSearchDSL(object):
         expected: Exception raised
         '''
         expr = {"must": [gen_default_vector_expr(default_query),
-                         gen_default_term_expr(keyword="terrm", values=[i for i in range(nb // 2)])]}
+                         gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]}
         query = update_query_expr(default_query, expr=expr)
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)
@ -1066,17 +1076,17 @@ class TestSearchDSL(object):
         connect.create_collection(collection_term, term_fields)
         term_entities = add_field(entities, field_name="term")
         ids = connect.insert(collection_term, term_entities)
-        assert len(ids) == nb
+        assert len(ids) == default_nb
         connect.flush([collection_term])
         count = connect.count_entities(collection_term)
-        assert count == nb
-        term_param = {"term": {"term": {"values": [i for i in range(nb // 2)]}}}
+        assert count == default_nb
+        term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
         expr = {"must": [gen_default_vector_expr(default_query),
                          term_param]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection_term, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k
         connect.drop_collection(collection_term)

     @pytest.mark.level(2)
@ -1153,7 +1163,7 @@ class TestSearchDSL(object):
         expected: 0
         '''
         entities, ids = init_data(connect, collection)
-        ranges = {"GT": nb, "LT": 0}
+        ranges = {"GT": default_nb, "LT": 0}
         range = gen_default_range_expr(ranges=ranges)
         expr = {"must": [gen_default_vector_expr(default_query), range]}
         query = update_query_expr(default_query, expr=expr)
@ -1180,7 +1190,7 @@ class TestSearchDSL(object):
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k

     def test_query_range_one_field_not_existed(self, connect, collection):
         '''
@ -1189,7 +1199,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         range = gen_default_range_expr()
-        range["range"].update({"a": {"GT": 1, "LT": nb // 2}})
+        range["range"].update({"a": {"GT": 1, "LT": default_nb // 2}})
         expr = {"must": [gen_default_vector_expr(default_query), range]}
         query = update_query_expr(default_query, expr=expr)
         with pytest.raises(Exception) as e:
@ -1210,12 +1220,12 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         term_first = gen_default_term_expr()
-        term_second = gen_default_term_expr(values=[i for i in range(nb // 3)])
+        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)])
         expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k

     # TODO
     @pytest.mark.level(2)
@ -1226,7 +1236,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         term_first = gen_default_term_expr()
-        term_second = gen_default_term_expr(values=[i for i in range(nb // 2, nb + nb // 2)])
+        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])
         expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
@ -1241,7 +1251,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         term_first = gen_default_term_expr()
-        term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(nb // 2, nb)])
+        term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(default_nb // 2, default_nb)])
         expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
@ -1256,8 +1266,8 @@ class TestSearchDSL(object):
         expected: pass
         '''
         entities, ids = init_data(connect, collection)
-        term_first = {"int64": {"values": [i for i in range(nb // 2)]}}
-        term_second = {"float": {"values": [float(i) for i in range(nb // 2, nb)]}}
+        term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}}
+        term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}}
         term = update_term_expr({"term": {}}, [term_first, term_second])
         expr = {"must": [gen_default_vector_expr(default_query), term]}
         query = update_query_expr(default_query, expr=expr)
@ -1273,12 +1283,12 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         range_one = gen_default_range_expr()
-        range_two = gen_default_range_expr(ranges={"GT": 1, "LT": nb // 3})
+        range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3})
         expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k

     # TODO
     @pytest.mark.level(2)
@ -1289,7 +1299,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         range_one = gen_default_range_expr()
-        range_two = gen_default_range_expr(ranges={"GT": nb // 2, "LT": nb})
+        range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
         expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
@ -1305,7 +1315,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         range_first = gen_default_range_expr()
-        range_second = gen_default_range_expr(field="float", ranges={"GT": nb // 2, "LT": nb})
+        range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb})
         expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
@ -1320,12 +1330,11 @@ class TestSearchDSL(object):
         expected: pass
         '''
         entities, ids = init_data(connect, collection)
-        range_first = {"int64": {"GT": 0, "LT": nb // 2}}
-        range_second = {"float": {"GT": nb / 2, "LT": float(nb)}}
+        range_first = {"int64": {"GT": 0, "LT": default_nb // 2}}
+        range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}}
         range = update_range_expr({"range": {}}, [range_first, range_second])
         expr = {"must": [gen_default_vector_expr(default_query), range]}
         query = update_query_expr(default_query, expr=expr)
-        res = connect.search(collection, query)
+        with pytest.raises(Exception) as e:
+            res = connect.search(collection, query)

@ -1344,12 +1353,12 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         term = gen_default_term_expr()
-        range = gen_default_range_expr(ranges={"GT": -1, "LT": nb // 2})
+        range = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2})
         expr = {"must": [gen_default_vector_expr(default_query), term, range]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
         assert len(res) == nq
-        assert len(res[0]) == top_k
+        assert len(res[0]) == default_top_k

     # TODO
     def test_query_single_term_range_no_common(self, connect, collection):
@ -1359,7 +1368,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         term = gen_default_term_expr()
-        range = gen_default_range_expr(ranges={"GT": nb // 2, "LT": nb})
+        range = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
         expr = {"must": [gen_default_vector_expr(default_query), term, range]}
         query = update_query_expr(default_query, expr=expr)
         res = connect.search(collection, query)
@ -1380,7 +1389,7 @@ class TestSearchDSL(object):
         '''
         entities, ids = init_data(connect, collection)
         vector1 = default_query
-        vector2 = gen_query_vectors(field_name, entities, top_k, nq=2)
+        vector2 = gen_query_vectors(field_name, entities, default_top_k, nq=2)
         expr = {
             "must": [vector1, vector2]
         }
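Every expression in this class is the same boolean DSL shape: a bool.must list that always carries one vector clause and optionally stacks term and range filters beside it. A sketch of a fully spelled-out query of that shape, with placeholder field names and values mirroring the defaults used above:

    query = {
        "bool": {
            "must": [
                {"vector": {"float_vector": {"topk": 10,
                                             "query": [[0.0] * 128],
                                             "metric_type": "L2",
                                             "params": {"nprobe": 10}}}},
                {"term": {"int64": {"values": [i for i in range(10)]}}},
                {"range": {"int64": {"GT": -1, "LT": 3000}}},
            ]
        }
    }

A search returns one result list per query vector, which is what the recurring assert len(res) == nq / assert len(res[0]) == default_top_k pair checks.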
@ -1509,8 +1518,9 @@ class TestSearchInvalid(object):
         with pytest.raises(Exception) as e:
             res = connect.search(collection_name, default_query)

-    @pytest.mark.level(1)
-    def test_search_with_invalid_tag(self, connect, collection):
+    # TODO(yukun)
+    @pytest.mark.level(2)
+    def _test_search_with_invalid_tag(self, connect, collection):
         tag = " "
         with pytest.raises(Exception) as e:
             res = connect.search(collection, default_query, partition_tags=tag)
@ -1541,7 +1551,7 @@ class TestSearchInvalid(object):
     @pytest.mark.level(1)
     def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
         '''
-        target: test search fuction, with the wrong top_k
+        target: test search function, with the wrong top_k
         method: search with top_k
         expected: raise an error, and the connection is normal
         '''
@ -1564,7 +1574,7 @@ class TestSearchInvalid(object):
     @pytest.mark.level(2)
     def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
         '''
-        target: test search fuction, with the wrong nprobe
+        target: test search function, with the wrong nprobe
         method: search with nprobe
         expected: raise an error, and the connection is normal
         '''
@ -1572,16 +1582,18 @@ class TestSearchInvalid(object):
         index_type = get_simple_index["index_type"]
         if index_type in ["FLAT"]:
             pytest.skip("skip in FLAT index")
+        if index_type != search_params["index_type"]:
+            pytest.skip("skip if index_type not matched")
         entities, ids = init_data(connect, collection)
         connect.create_index(collection, field_name, get_simple_index)
-        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params=search_params["search_params"])
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)

     @pytest.mark.level(2)
     def test_search_with_invalid_params_binary(self, connect, binary_collection):
         '''
-        target: test search fuction, with the wrong nprobe
+        target: test search function, with the wrong nprobe
         method: search with nprobe
         expected: raise an error, and the connection is normal
         '''
@ -1589,15 +1601,15 @@ class TestSearchInvalid(object):
         index_type = "BIN_IVF_FLAT"
         int_vectors, entities, ids = init_binary_data(connect, binary_collection)
         query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
-        connect.create_index(binary_collection, binary_field_name, {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 1024}})
-        query, vecs = gen_query_vectors(binary_field_name, query_entities, top_k, nq, search_params={"nprobe": 0}, metric_type="JACCARD")
+        connect.create_index(binary_collection, binary_field_name, {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}})
+        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, search_params={"nprobe": 0}, metric_type="JACCARD")
         with pytest.raises(Exception) as e:
             res = connect.search(binary_collection, query)

     @pytest.mark.level(2)
     def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
         '''
-        target: test search fuction, with empty search params
+        target: test search function, with empty search params
         method: search with params
         expected: raise an error, and the connection is normal
         '''
@ -1608,7 +1620,7 @@ class TestSearchInvalid(object):
             pytest.skip("skip in FLAT index")
         entities, ids = init_data(connect, collection)
         connect.create_index(collection, field_name, get_simple_index)
-        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
+        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={})
         with pytest.raises(Exception) as e:
             res = connect.search(collection, query)
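All of these negative cases reduce to one pytest idiom: build a deliberately bad query, call search, and require an exception while the connection stays usable. A generic sketch, where client stands for any connected handle exposing the search and count_entities signatures used above:

    import pytest

    def assert_search_rejected(client, collection, bad_query):
        # the server must reject the request without dropping the session
        with pytest.raises(Exception):
            client.search(collection, bad_query)
        # the same handle should still answer a follow-up call
        assert client.count_entities(collection) >= 0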
@ -7,16 +7,7 @@ from multiprocessing import Pool, Process
 import pytest
 from utils import *

-dim = 128
-index_file_size = 10
-collection_id = "mysql_failure"
-nprobe = 1
-tag = "1970_01_01"

 class TestMysql:

     """
     ******************************************************************
     The following cases are used to test mysql failure
@ -33,7 +24,7 @@ class TestMysql:
         big_nb = 20000
         index_param = {"nlist": 1024, "m": 16}
         index_type = IndexType.IVF_PQ
-        vectors = gen_vectors(big_nb, dim)
+        vectors = gen_vectors(big_nb, default_dim)
         status, ids = connect.insert(collection, vectors, ids=[i for i in range(big_nb)])
         status = connect.flush([collection])
         assert status.OK()
@ -3,131 +3,313 @@ import random
|
|||
import pdb
|
||||
import threading
|
||||
import logging
|
||||
import json
|
||||
from multiprocessing import Pool, Process
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
|
||||
dim = 128
|
||||
index_file_size = 10
|
||||
collection_id = "test_partition_restart"
|
||||
nprobe = 1
|
||||
tag = "1970_01_01"
|
||||
uid = "wal"
|
||||
TIMEOUT = 120
|
||||
insert_interval_time = 1.5
|
||||
big_nb = 100000
|
||||
field_name = "float_vector"
|
||||
big_entities = gen_entities(big_nb)
|
||||
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
|
||||
|
||||
|
||||
class TestRestartBase:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_partition` function
|
||||
******************************************************************
|
||||
"""
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def skip_check(self, connect, args):
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def skip_check(self, args):
|
||||
logging.getLogger().info(args)
|
||||
if "service_name" not in args or not args["service_name"]:
|
||||
reason = "Skip if service name not provided"
|
||||
logging.getLogger().info(reason)
|
||||
pytest.skip(reason)
|
||||
if args["service_name"].find("shards") != -1:
|
||||
reason = "Skip restart cases in shards mode"
|
||||
logging.getLogger().info(reason)
|
||||
pytest.skip(reason)
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_create_partition_insert_restart(self, connect, collection, args):
|
||||
def _test_insert_flush(self, connect, collection, args):
|
||||
'''
|
||||
target: return the same row count after server restart
|
||||
method: call function: create partition, then insert, restart server and assert row count
|
||||
expected: status ok, and row count keep the same
|
||||
method: call function: create collection, then insert/flush, restart server and assert row count
|
||||
expected: row count stays the same
|
||||
'''
|
||||
status = connect.create_partition(collection, tag)
|
||||
assert status.OK()
|
||||
nq = 1000
|
||||
vectors = gen_vectors(nq, dim)
|
||||
ids = [i for i in range(nq)]
|
||||
status, ids = connect.insert(collection, vectors, ids, partition_tag=tag)
|
||||
assert status.OK()
|
||||
status = connect.flush([collection])
|
||||
assert status.OK()
|
||||
status, res = connect.count_entities(collection)
|
||||
logging.getLogger().info(res)
|
||||
assert res == nq
|
||||
|
||||
# restart server
|
||||
if restart_server(args["service_name"]):
|
||||
logging.getLogger().info("Restart success")
|
||||
else:
|
||||
logging.getLogger().info("Restart failed")
|
||||
# assert row count again
|
||||
|
||||
# debug
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
status, res = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info(res)
|
||||
assert status.OK()
|
||||
assert res == nq
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_during_creating_index_restart(self, connect, collection, args):
|
||||
'''
|
||||
target: return the same row count after server restart
|
||||
method: call function: insert, flush, and create index; the server restarts while the index is being created
|
||||
expected: row count, vector-id, index info keep the same
|
||||
'''
|
||||
# reset auto_flush_interval
|
||||
# auto_flush_interval = 100
|
||||
get_ids_length = 500
|
||||
timeout = 60
|
||||
big_nb = 20000
|
||||
index_param = {"nlist": 1024, "m": 16}
|
||||
index_type = IndexType.IVF_PQ
|
||||
# status, res_set = connect.set_config("db_config", "auto_flush_interval", auto_flush_interval)
|
||||
# assert status.OK()
|
||||
# status, res_get = connect.get_config("db_config", "auto_flush_interval")
|
||||
# assert status.OK()
|
||||
# assert res_get == str(auto_flush_interval)
|
||||
# insert and create index
|
||||
vectors = gen_vectors(big_nb, dim)
|
||||
status, ids = connect.insert(collection, vectors, ids=[i for i in range(big_nb)])
|
||||
status = connect.flush([collection])
|
||||
assert status.OK()
|
||||
status, res_count = connect.count_entities(collection)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
assert status.OK()
|
||||
assert res_count == big_nb
|
||||
logging.getLogger().info("Start create index async")
|
||||
status = connect.create_index(collection, index_type, index_param, _async=True)
|
||||
time.sleep(2)
|
||||
assert res_count == 2 * nb
|
||||
# restart server
|
||||
logging.getLogger().info("Before restart server")
|
||||
if restart_server(args["service_name"]):
|
||||
logging.getLogger().info("Restart success")
|
||||
else:
|
||||
logging.getLogger().info("Restart failed")
|
||||
# check row count, index_type, vector-id after server restart
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
status, res_count = new_connect.count_entities(collection)
|
||||
assert status.OK()
|
||||
assert res_count == big_nb
|
||||
status, res_info = new_connect.get_index_info(collection)
|
||||
logging.getLogger().info(res_info)
|
||||
assert res_info._params == index_param
|
||||
assert res_info._collection_name == collection
|
||||
assert res_info._index_type == index_type
|
||||
logging.getLogger().info("Start restart server")
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
assert res_count == 2 * nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_insert_during_flushing(self, connect, collection, args):
|
||||
'''
|
||||
target: flushing will recover
|
||||
method: call function: create collection, then insert/flushing, restart server and assert row count
|
||||
expected: row count recovers to big_nb after restart
|
||||
'''
|
||||
# disable_autoflush()
|
||||
ids = connect.insert(collection, big_entities)
|
||||
connect.flush([collection], _async=True)
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
if res_count < big_nb:
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count_2 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_2)
|
||||
timeout = 300
|
||||
start_time = time.time()
|
||||
while new_connect.count_entities(collection) != big_nb and (time.time() - start_time < timeout):
|
||||
time.sleep(10)
|
||||
logging.getLogger().info(new_connect.count_entities(collection))
|
||||
res_count_3 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_3)
|
||||
assert res_count_3 == big_nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_delete_during_flushing(self, connect, collection, args):
|
||||
'''
|
||||
target: flushing will recover
|
||||
method: call function: create collection, then delete/flushing, restart server and assert row count
|
||||
expected: row count equals (big_nb - delete_length)
|
||||
'''
|
||||
# disable_autoflush()
|
||||
ids = connect.insert(collection, big_entities)
|
||||
connect.flush([collection])
|
||||
delete_length = 1000
|
||||
delete_ids = ids[big_nb//4:big_nb//4+delete_length]
|
||||
delete_res = connect.delete_entity_by_id(collection, delete_ids)
|
||||
connect.flush([collection], _async=True)
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count_2 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_2)
|
||||
timeout = 100
|
||||
start_time = time.time()
|
||||
i = 1
|
||||
while time.time() - start_time < timeout:
|
||||
status, stats = new_connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(i)
|
||||
logging.getLogger().info(stats["partitions"])
|
||||
index_name = stats["partitions"][0]["segments"][0]["index_name"]
|
||||
if index_name == "PQ":
|
||||
break
|
||||
time.sleep(4)
|
||||
i += 1
|
||||
if time.time() - start_time >= timeout:
|
||||
logging.getLogger().info("Timeout")
|
||||
assert False
|
||||
get_ids = random.sample(ids, get_ids_length)
|
||||
status, res = new_connect.get_entity_by_id(collection, get_ids)
|
||||
while new_connect.count_entities(collection) != big_nb - delete_length and (time.time() - start_time < timeout):
|
||||
time.sleep(10)
|
||||
logging.getLogger().info(new_connect.count_entities(collection))
|
||||
if new_connect.count_entities(collection) == big_nb - delete_length:
|
||||
time.sleep(10)
|
||||
res_count_3 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_3)
|
||||
assert res_count_3 == big_nb - delete_length
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_during_indexed(self, connect, collection, args):
|
||||
'''
|
||||
target: flushing will recover
|
||||
method: call function: create collection, then indexed, restart server and assert row count
|
||||
expected: row count equals big_nb
|
||||
'''
|
||||
# disable_autoflush()
|
||||
ids = connect.insert(collection, big_entities)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, default_index)
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# logging.getLogger().info(stats)
|
||||
# pdb.set_trace()
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
assert new_connect.count_entities(collection) == big_nb
|
||||
stats = connect.get_collection_stats(collection)
|
||||
for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
if file["field"] == field_name and file["name"] != "_raw":
|
||||
assert file["data_size"] > 0
|
||||
if file["index_type"] != default_index["index_type"]:
|
||||
assert False
|
||||
else:
|
||||
assert True
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_during_indexing(self, connect, collection, args):
|
||||
'''
|
||||
target: flushing will recover
|
||||
method: call function: create collection, then indexing, restart server and assert row count
|
||||
expected: row count equals loop * big_nb, and the server continues building the index after restart
|
||||
'''
|
||||
# disable_autoflush()
|
||||
loop = 5
|
||||
for i in range(loop):
|
||||
ids = connect.insert(collection, big_entities)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, default_index, _async=True)
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# logging.getLogger().info(stats)
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count_2 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_2)
|
||||
assert res_count_2 == loop * big_nb
|
||||
status = new_connect._cmd("status")
|
||||
assert json.loads(status)["indexing"] == True
|
||||
# timeout = 100
|
||||
# start_time = time.time()
|
||||
# while time.time() - start_time < timeout:
|
||||
# time.sleep(5)
|
||||
# assert new_connect.count_entities(collection) == loop * big_nb
|
||||
# stats = connect.get_collection_stats(collection)
|
||||
# assert stats["row_count"] == loop * big_nb
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# # logging.getLogger().info(file)
|
||||
# if file["field"] == field_name and file["name"] != "_raw":
|
||||
# assert file["data_size"] > 0
|
||||
# if file["index_type"] != default_index["index_type"]:
|
||||
# continue
|
||||
# for file in stats["partitions"][0]["segments"][0]["files"]:
|
||||
# if file["field"] == field_name and file["name"] != "_raw":
|
||||
# assert file["data_size"] > 0
|
||||
# if file["index_type"] != default_index["index_type"]:
|
||||
# assert False
|
||||
# else:
|
||||
# assert True
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_delete_flush_during_compacting(self, connect, collection, args):
|
||||
'''
|
||||
target: verify server work after restart during compaction
|
||||
method: call function: create collection, then delete/flush/compacting, restart server and assert row count
|
||||
call `compact` again, compact pass
|
||||
expected: row count equals (big_nb - delete_length * loop)
|
||||
'''
|
||||
# disable_autoflush()
|
||||
ids = connect.insert(collection, big_entities)
|
||||
connect.flush([collection])
|
||||
delete_length = 1000
|
||||
loop = 10
|
||||
for i in range(loop):
|
||||
delete_ids = ids[i*delete_length:(i+1)*delete_length]
|
||||
delete_res = connect.delete_entity_by_id(collection, delete_ids)
|
||||
connect.flush([collection])
|
||||
connect.compact(collection, _async=True)
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
assert res_count == big_nb - delete_length*loop
|
||||
info = connect.get_collection_stats(collection)
|
||||
size_old = info["partitions"][0]["segments"][0]["data_size"]
|
||||
logging.getLogger().info(size_old)
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count_2 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_2)
|
||||
assert res_count_2 == big_nb - delete_length*loop
|
||||
info = connect.get_collection_stats(collection)
|
||||
size_before = info["partitions"][0]["segments"][0]["data_size"]
|
||||
status = connect.compact(collection)
|
||||
assert status.OK()
|
||||
for index, item_id in enumerate(get_ids):
|
||||
assert_equal_vector(res[index], vectors[item_id])
|
||||
info = connect.get_collection_stats(collection)
|
||||
size_after = info["partitions"][0]["segments"][0]["data_size"]
|
||||
assert size_before > size_after
|
||||
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_insert_during_flushing_multi_collections(self, connect, args):
|
||||
'''
|
||||
target: flushing will recover
|
||||
method: call function: create collections, then insert/flushing, restart server and assert row count
|
||||
expected: each collection recovers to big_nb rows after restart
|
||||
'''
|
||||
# disable_autoflush()
|
||||
collection_num = 2
|
||||
collection_list = []
|
||||
for i in range(collection_num):
|
||||
collection_name = gen_unique_str(uid)
|
||||
collection_list.append(collection_name)
|
||||
connect.create_collection(collection_name, default_fields)
|
||||
ids = connect.insert(collection_name, big_entities)
|
||||
connect.flush(collection_list, _async=True)
|
||||
res_count = connect.count_entities(collection_list[-1])
|
||||
logging.getLogger().info(res_count)
|
||||
if res_count < big_nb:
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count_2 = new_connect.count_entities(collection_list[-1])
|
||||
logging.getLogger().info(res_count_2)
|
||||
timeout = 300
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
count_list = []
|
||||
break_flag = True
|
||||
for index, name in enumerate(collection_list):
|
||||
tmp_count = new_connect.count_entities(name)
|
||||
count_list.append(tmp_count)
|
||||
logging.getLogger().info(count_list)
|
||||
if tmp_count != big_nb:
|
||||
break_flag = False
|
||||
break
|
||||
if break_flag == True:
|
||||
break
|
||||
time.sleep(10)
|
||||
for name in collection_list:
|
||||
assert new_connect.count_entities(name) == big_nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_insert_during_flushing_multi_partitions(self, connect, collection, args):
|
||||
'''
|
||||
target: flushing will recover
|
||||
method: call function: create collection/partition, then insert/flushing, restart server and assert row count
|
||||
expected: row count recovers to big_nb * 2 after restart
|
||||
'''
|
||||
# disable_autoflush()
|
||||
partitions_num = 2
|
||||
partitions = []
|
||||
for i in range(partitions_num):
|
||||
tag_tmp = gen_unique_str()
|
||||
partitions.append(tag_tmp)
|
||||
connect.create_partition(collection, tag_tmp)
|
||||
ids = connect.insert(collection, big_entities, partition_tag=tag_tmp)
|
||||
connect.flush([collection], _async=True)
|
||||
res_count = connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count)
|
||||
if res_count < big_nb:
|
||||
# restart server
|
||||
assert restart_server(args["service_name"])
|
||||
# assert row count again
|
||||
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
res_count_2 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_2)
|
||||
timeout = 300
|
||||
start_time = time.time()
|
||||
while new_connect.count_entities(collection) != big_nb * 2 and (time.time() - start_time < timeout):
|
||||
time.sleep(10)
|
||||
logging.getLogger().info(new_connect.count_entities(collection))
|
||||
res_count_3 = new_connect.count_entities(collection)
|
||||
logging.getLogger().info(res_count_3)
|
||||
assert res_count_3 == big_nb * 2
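Most of the restart cases above converge on the same loop: after the server comes back, poll count_entities until it reaches the expected total or a deadline passes. A standalone sketch of that wait (wait_for_count is illustrative, not a helper from utils):

    import time

    def wait_for_count(get_count, expected, timeout=300, interval=10):
        # poll get_count() until it returns expected or timeout seconds pass
        start = time.time()
        while time.time() - start < timeout:
            if get_count() == expected:
                return True
            time.sleep(interval)
        return False

Used as wait_for_count(lambda: new_connect.count_entities(collection), big_nb), it replaces the inline while loops and makes the timeout explicit.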
@ -5,28 +5,15 @@ import logging
 from multiprocessing import Pool, Process
 import pytest
 from utils import *
+from constants import *

-dim = 128
-index_file_size = 10
 COMPACT_TIMEOUT = 180
-nprobe = 1
-top_k = 1
-tag = "1970_01_01"
-nb = 6000
-nq = 2
-segment_row_count = 5000
-entity = gen_entities(1)
-entities = gen_entities(nb)
-raw_vector, binary_entity = gen_binary_entities(1)
-raw_vectors, binary_entities = gen_binary_entities(nb)
-default_fields = gen_default_fields()
-default_binary_fields = gen_binary_default_fields()
 field_name = default_float_vec_field_name
 binary_field_name = default_binary_vec_field_name
 default_single_query = {
     "bool": {
         "must": [
-            {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim), "metric_type":"L2",
+            {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, default_dim), "metric_type":"L2",
                                      "params": {"nprobe": 10}}}}
         ]
     }
@ -34,12 +21,12 @@ default_single_query = {
 default_binary_single_query = {
     "bool": {
         "must": [
-            {"vector": {binary_field_name: {"topk": 10, "query": gen_binary_vectors(1, dim), "metric_type":"JACCARD",
-                                            "params": {"nprobe": 10}}}}
+            {"vector": {binary_field_name: {"topk": 10, "query": gen_binary_vectors(1, default_dim),
+                                            "metric_type":"JACCARD", "params": {"nprobe": 10}}}}
         ]
     }
 }
-default_query, default_query_vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
+default_query, default_query_vecs = gen_query_vectors(binary_field_name, default_binary_entities, 1, 2)

 def ip_query():
@ -83,6 +70,14 @@ class TestCompactBase:
     def get_collection_name(self, request):
         yield request.param

+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_ints()
+    )
+    def get_threshold(self, request):
+        yield request.param
+
     @pytest.mark.level(2)
     @pytest.mark.timeout(COMPACT_TIMEOUT)
     def test_compact_collection_name_invalid(self, connect, get_collection_name):
         '''
@ -94,7 +89,20 @@ class TestCompactBase:
         with pytest.raises(Exception) as e:
             status = connect.compact(collection_name)
         # assert not status.OK()

+    @pytest.mark.level(2)
+    @pytest.mark.timeout(COMPACT_TIMEOUT)
+    def test_compact_threshold_invalid(self, connect, collection, get_threshold):
+        '''
+        target: test compact collection with an invalid threshold
+        method: compact with invalid threshold
+        expected: exception raised
+        '''
+        threshold = get_threshold
+        if threshold is not None:
+            with pytest.raises(Exception) as e:
+                status = connect.compact(collection, threshold)

     @pytest.mark.level(2)
     @pytest.mark.timeout(COMPACT_TIMEOUT)
     def test_add_entity_and_compact(self, connect, collection):
@ -104,7 +112,7 @@ class TestCompactBase:
|
|||
expected: data_size before and after Compact
|
||||
'''
|
||||
# vector = gen_single_vector(dim)
|
||||
ids = connect.insert(collection, entity)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
assert len(ids) == 1
|
||||
connect.flush([collection])
|
||||
# get collection info before compact
|
||||
|
@ -125,8 +133,7 @@ class TestCompactBase:
|
|||
method: add entities and compact collection
|
||||
expected: data_size before and after Compact
|
||||
'''
|
||||
# entities = gen_vector(nb, dim)
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
|
@ -147,8 +154,8 @@ class TestCompactBase:
|
|||
method: add entities, delete a few and compact collection
|
||||
expected: status ok, data size maybe is smaller after compact
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
status = connect.delete_entity_by_id(collection, delete_ids)
|
||||
|
@ -167,7 +174,35 @@ class TestCompactBase:
|
|||
size_after = info["partitions"][0]["data_size"]
|
||||
logging.getLogger().info(size_after)
|
||||
assert(size_before >= size_after)
|
||||
|
||||
|
||||
@pytest.mark.timeout(COMPACT_TIMEOUT)
|
||||
def test_insert_delete_part_and_compact_threshold(self, connect, collection):
|
||||
'''
|
||||
target: test add entities, delete part of them and compact
|
||||
method: add entities, delete a few and compact collection
|
||||
expected: status ok, data size maybe is smaller after compact
|
||||
'''
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
status = connect.delete_entity_by_id(collection, delete_ids)
|
||||
assert status.OK()
|
||||
connect.flush([collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(info["partitions"])
|
||||
size_before = info["partitions"][0]["data_size"]
|
||||
logging.getLogger().info(size_before)
|
||||
status = connect.compact(collection, 0.1)
|
||||
assert status.OK()
|
||||
# get collection info after compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(info["partitions"])
|
||||
size_after = info["partitions"][0]["data_size"]
|
||||
logging.getLogger().info(size_after)
|
||||
assert(size_before >= size_after)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(COMPACT_TIMEOUT)
|
||||
def test_insert_delete_all_and_compact(self, connect, collection):
|
||||
|
@ -176,8 +211,8 @@ class TestCompactBase:
|
|||
method: add entities, delete all and compact collection
|
||||
expected: status ok, no data size in collection info because collection is empty
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
status = connect.delete_entity_by_id(collection, ids)
|
||||
assert status.OK()
|
||||
|
@ -200,14 +235,13 @@ class TestCompactBase:
|
|||
method: add entities, delete half of entities in partition and compact collection
|
||||
expected: status ok, data_size less than the older version
|
||||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
assert connect.has_partition(collection, tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
assert connect.has_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
info = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(info["partitions"])
|
||||
|
||||
delete_ids = ids[:3000]
|
||||
delete_ids = ids[:default_nb//2]
|
||||
status = connect.delete_entity_by_id(collection, delete_ids)
|
||||
assert status.OK()
|
||||
connect.flush([collection])
|
||||
|
@ -242,15 +276,14 @@ class TestCompactBase:
|
|||
expected: status ok, index description no change, data size smaller after compact
|
||||
'''
|
||||
count = 10
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.flush([collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
size_before = info["partitions"][0]["segments"][0]["data_size"]
|
||||
logging.getLogger().info(info["partitions"])
|
||||
delete_ids = ids[:1500]
|
||||
delete_ids = ids[:default_nb//2]
|
||||
status = connect.delete_entity_by_id(collection, delete_ids)
|
||||
assert status.OK()
|
||||
connect.flush([collection])
|
||||
|
@ -258,7 +291,6 @@ class TestCompactBase:
|
|||
assert status.OK()
|
||||
# get collection info after compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
logging.getLogger().info(info["partitions"])
|
||||
size_after = info["partitions"][0]["segments"][0]["data_size"]
|
||||
assert(size_before >= size_after)
|
||||
|
||||
|
@ -269,7 +301,7 @@ class TestCompactBase:
|
|||
method: add entity and compact collection twice
|
||||
expected: status ok, data size no change
|
||||
'''
|
||||
ids = connect.insert(collection, entity)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
|
@ -295,7 +327,7 @@ class TestCompactBase:
|
|||
method: add entities, delete part and compact collection twice
|
||||
expected: status ok, data size smaller after first compact, no change after second
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
status = connect.delete_entity_by_id(collection, delete_ids)
|
||||
|
@ -345,8 +377,8 @@ class TestCompactBase:
|
|||
method: after compact operation, add entity
|
||||
expected: status ok, entity added
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(collection)
|
||||
|
@ -357,10 +389,10 @@ class TestCompactBase:
|
|||
info = connect.get_collection_stats(collection)
|
||||
size_after = info["partitions"][0]["segments"][0]["data_size"]
|
||||
assert(size_before == size_after)
|
||||
ids = connect.insert(collection, entity)
|
||||
ids = connect.insert(collection, default_entity)
|
||||
connect.flush([collection])
|
||||
res = connect.count_entities(collection)
|
||||
assert res == nb+1
|
||||
assert res == default_nb+1
|
||||
|
||||
@pytest.mark.timeout(COMPACT_TIMEOUT)
|
||||
def test_index_creation_after_compact(self, connect, collection, get_simple_index):
|
||||
|
@ -369,7 +401,7 @@ class TestCompactBase:
|
|||
method: after compact operation, create index
|
||||
expected: status ok, index description no change
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
status = connect.delete_entity_by_id(collection, ids[:10])
|
||||
assert status.OK()
|
||||
|
@ -387,8 +419,8 @@ class TestCompactBase:
|
|||
method: after compact operation, delete entities
|
||||
expected: status ok, entities deleted
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
status = connect.compact(collection)
|
||||
assert status.OK()
|
||||
|
@ -405,14 +437,15 @@ class TestCompactBase:
|
|||
method: after compact operation, search vector
|
||||
expected: status ok
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
status = connect.compact(collection)
|
||||
assert status.OK()
|
||||
query = copy.deepcopy(default_single_query)
|
||||
query["bool"]["must"][0]["vector"][field_name]["query"] = [entity[-1]["values"][0], entities[-1]["values"][0],
|
||||
entities[-1]["values"][-1]]
|
||||
query["bool"]["must"][0]["vector"][field_name]["query"] = [default_entity[-1]["values"][0],
|
||||
default_entities[-1]["values"][0],
|
||||
default_entities[-1]["values"][-1]]
|
||||
res = connect.search(collection, query)
|
||||
logging.getLogger().debug(res)
|
||||
assert len(res) == len(query["bool"]["must"][0]["vector"][field_name]["query"])
|
||||
|
@ -434,7 +467,7 @@ class TestCompactBinary:
|
|||
method: add vector and compact collection
|
||||
expected: status ok, vector added
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entity)
|
||||
ids = connect.insert(binary_collection, default_binary_entity)
|
||||
assert len(ids) == 1
|
||||
connect.flush([binary_collection])
|
||||
# get collection info before compact
|
||||
|
@ -454,8 +487,8 @@ class TestCompactBinary:
|
|||
method: add entities and compact collection
|
||||
expected: status ok, entities added
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([binary_collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(binary_collection)
|
||||
|
@ -474,8 +507,8 @@ class TestCompactBinary:
|
|||
method: add entities, delete a few and compact collection
|
||||
expected: status ok, data size is smaller after compact
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([binary_collection])
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
status = connect.delete_entity_by_id(binary_collection, delete_ids)
|
||||
|
@ -503,8 +536,8 @@ class TestCompactBinary:
|
|||
method: add entities, delete all and compact collection
|
||||
expected: status ok, no data size in collection info because collection is empty
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([binary_collection])
|
||||
status = connect.delete_entity_by_id(binary_collection, ids)
|
||||
assert status.OK()
|
||||
|
@ -526,7 +559,7 @@ class TestCompactBinary:
|
|||
method: add entity and compact collection twice
|
||||
expected: status ok
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entity)
|
||||
ids = connect.insert(binary_collection, default_binary_entity)
|
||||
assert len(ids) == 1
|
||||
connect.flush([binary_collection])
|
||||
# get collection info before compact
|
||||
|
@ -552,8 +585,8 @@ class TestCompactBinary:
|
|||
method: add entities, delete part and compact collection twice
|
||||
expected: status ok, data size smaller after first compact, no change after second
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([binary_collection])
|
||||
delete_ids = [ids[0], ids[-1]]
|
||||
status = connect.delete_entity_by_id(binary_collection, delete_ids)
|
||||
|
@ -609,7 +642,7 @@ class TestCompactBinary:
|
|||
method: after compact operation, add entity
|
||||
expected: status ok, entity added
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
# get collection info before compact
|
||||
info = connect.get_collection_stats(binary_collection)
|
||||
|
@ -620,10 +653,10 @@ class TestCompactBinary:
|
|||
info = connect.get_collection_stats(binary_collection)
|
||||
size_after = info["partitions"][0]["segments"][0]["data_size"]
|
||||
assert(size_before == size_after)
|
||||
ids = connect.insert(binary_collection, binary_entity)
|
||||
ids = connect.insert(binary_collection, default_binary_entity)
|
||||
connect.flush([binary_collection])
|
||||
res = connect.count_entities(binary_collection)
|
||||
assert res == nb + 1
|
||||
assert res == default_nb + 1
|
||||
|
||||
@pytest.mark.timeout(COMPACT_TIMEOUT)
|
||||
def test_delete_entities_after_compact(self, connect, binary_collection):
|
||||
|
@ -632,7 +665,7 @@ class TestCompactBinary:
|
|||
method: after compact operation, delete entities
|
||||
expected: status ok, entities deleted
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
connect.flush([binary_collection])
|
||||
status = connect.compact(binary_collection)
|
||||
assert status.OK()
|
||||
|
@ -651,16 +684,16 @@ class TestCompactBinary:
|
|||
method: after compact operation, search vector
|
||||
expected: status ok
|
||||
'''
|
||||
ids = connect.insert(binary_collection, binary_entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(binary_collection, default_binary_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([binary_collection])
|
||||
status = connect.compact(binary_collection)
|
||||
assert status.OK()
|
||||
query_vecs = [raw_vectors[0]]
|
||||
distance = jaccard(query_vecs[0], raw_vectors[0])
|
||||
query_vecs = [default_raw_binary_vectors[0]]
|
||||
distance = jaccard(query_vecs[0], default_raw_binary_vectors[0])
|
||||
query = copy.deepcopy(default_binary_single_query)
|
||||
query["bool"]["must"][0]["vector"][binary_field_name]["query"] = [binary_entities[-1]["values"][0],
|
||||
binary_entities[-1]["values"][-1]]
|
||||
query["bool"]["must"][0]["vector"][binary_field_name]["query"] = [default_binary_entities[-1]["values"][0],
|
||||
default_binary_entities[-1]["values"][-1]]
|
||||
|
||||
res = connect.search(binary_collection, query)
|
||||
assert abs(res[0]._distances[0]-distance) <= epsilon
|
||||
|
@ -672,13 +705,14 @@ class TestCompactBinary:
|
|||
method: after compact operation, search vector
|
||||
expected: status ok
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
ids = connect.insert(collection, default_entities)
|
||||
assert len(ids) == default_nb
|
||||
connect.flush([collection])
|
||||
status = connect.compact(collection)
|
||||
query = ip_query()
|
||||
query["bool"]["must"][0]["vector"][field_name]["query"] = [entity[-1]["values"][0], entities[-1]["values"][0],
|
||||
entities[-1]["values"][-1]]
|
||||
query["bool"]["must"][0]["vector"][field_name]["query"] = [default_entity[-1]["values"][0],
|
||||
default_entities[-1]["values"][0],
|
||||
default_entities[-1]["values"][-1]]
|
||||
res = connect.search(collection, query)
|
||||
logging.getLogger().info(res)
|
||||
assert len(res) == len(query["bool"]["must"][0]["vector"][field_name]["query"])
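Across both compact classes the core invariant is identical: insert, delete a slice, flush, compact, and the partition's data_size must not grow. A condensed sketch using the same client calls the tests exercise (connect is the pymilvus 0.x handle the fixtures provide):

    def compact_shrinks_segment(connect, collection, delete_ids):
        # delete a slice, flush, then compare partition size across compact
        assert connect.delete_entity_by_id(collection, delete_ids).OK()
        connect.flush([collection])
        info = connect.get_collection_stats(collection)
        size_before = info["partitions"][0]["data_size"]
        assert connect.compact(collection).OK()
        info = connect.get_collection_stats(collection)
        size_after = info["partitions"][0]["data_size"]
        return size_before >= size_after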
@ -8,15 +8,7 @@ import pytest
|
|||
from utils import *
|
||||
import ujson
|
||||
|
||||
|
||||
dim = 128
|
||||
index_file_size = 10
|
||||
CONFIG_TIMEOUT = 80
|
||||
nprobe = 1
|
||||
top_k = 1
|
||||
tag = "1970_01_01"
|
||||
nb = 6000
|
||||
|
||||
|
||||
class TestCacheConfig:
    """

@@ -35,8 +27,8 @@ class TestCacheConfig:
        '''
        reset configs so the tests are stable
        '''
-        relpy = connect.set_config("cache", "cache_size", '4GB')
-        config_value = connect.get_config("cache", "cache_size")
+        relpy = connect.set_config("cache.cache_size", '4GB')
+        config_value = connect.get_config("cache.cache_size")
        assert config_value == '4GB'
        # relpy = connect.set_config("cache", "insert_buffer_size", '2GB')
        # config_value = connect.get_config("cache", "insert_buffer_size")

@@ -52,7 +44,7 @@ class TestCacheConfig:
        invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "cache_size")
+                config_value = connect.get_config(config+str(".cache_size"))

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_cache_size_invalid_child_key(self, connect, collection):

@@ -64,7 +56,7 @@ class TestCacheConfig:
        invalid_configs = ["Cpu_cache_size", "cpu cache_size", "cpucachecapacity"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("cache", config)
+                config_value = connect.get_config("cache."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_cache_size_valid(self, connect, collection):

@@ -73,7 +65,7 @@ class TestCacheConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("cache", "cache_size")
+        config_value = connect.get_config("cache.cache_size")
        assert config_value

    @pytest.mark.level(2)

@@ -86,7 +78,7 @@ class TestCacheConfig:
        invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "insert_buffer_size")
+                config_value = connect.get_config(config+".insert_buffer_size")

    @pytest.mark.level(2)
    def test_get_insert_buffer_size_invalid_child_key(self, connect, collection):

@@ -98,7 +90,7 @@ class TestCacheConfig:
        invalid_configs = ["Insert_buffer size", "insert buffer_size", "insertbuffersize"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("cache", config)
+                config_value = connect.get_config("cache."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_insert_buffer_size_valid(self, connect, collection):

@@ -107,7 +99,7 @@ class TestCacheConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("cache", "insert_buffer_size")
+        config_value = connect.get_config("cache.insert_buffer_size")
        assert config_value

    @pytest.mark.level(2)

@@ -120,7 +112,7 @@ class TestCacheConfig:
        invalid_configs = ["preloadtable", "preload collection "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("cache", config)
+                config_value = connect.get_config("cache."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_preload_collection_valid(self, connect, collection):

@@ -129,7 +121,7 @@ class TestCacheConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("cache", "preload_collection")
+        config_value = connect.get_config("cache.preload_collection")
        assert config_value == ''

    """

@@ -165,7 +157,7 @@ class TestCacheConfig:
        invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config(config, "cache_size", '4294967296')
+                relpy = connect.set_config(config+".cache_size", '4294967296')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -179,7 +171,7 @@ class TestCacheConfig:
        invalid_configs = ["abc", 1]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("cache", config, '4294967296')
+                relpy = connect.set_config("cache."+config, '4294967296')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -190,8 +182,8 @@ class TestCacheConfig:
        expected: status ok, set successfully
        '''
        self.reset_configs(connect)
-        relpy = connect.set_config("cache", "cache_size", '2147483648')
-        config_value = connect.get_config("cache", "cache_size")
+        relpy = connect.set_config("cache.cache_size", '2147483648')
+        config_value = connect.get_config("cache.cache_size")
        assert config_value == '2GB'

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")

@@ -204,12 +196,12 @@ class TestCacheConfig:
        '''
        self.reset_configs(connect)
        for i in range(20):
-            relpy = connect.set_config("cache", "cache_size", '4294967296')
-            config_value = connect.get_config("cache", "cache_size")
+            relpy = connect.set_config("cache.cache_size", '4294967296')
+            config_value = connect.get_config("cache.cache_size")
            assert config_value == '4294967296'
        for i in range(20):
-            relpy = connect.set_config("cache", "cache_size", '2147483648')
-            config_value = connect.get_config("cache", "cache_size")
+            relpy = connect.set_config("cache.cache_size", '2147483648')
+            config_value = connect.get_config("cache.cache_size")
            assert config_value == '2147483648'

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")

@@ -224,7 +216,7 @@ class TestCacheConfig:
        invalid_configs = ["Cache_config", "cache config", "cache_Config", "cacheconfig"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config(config, "insert_buffer_size", '1073741824')
+                relpy = connect.set_config(config+".insert_buffer_size", '1073741824')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -235,7 +227,7 @@ class TestCacheConfig:
        expected: status ok, set successfully
        '''
        self.reset_configs(connect)
-        relpy = connect.set_config("cache", "insert_buffer_size", '2GB')
+        relpy = connect.set_config("cache.insert_buffer_size", '2GB')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.level(2)

@@ -248,10 +240,10 @@ class TestCacheConfig:
        self.reset_configs(connect)
        for i in range(20):
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("cache", "insert_buffer_size", '1GB')
+                relpy = connect.set_config("cache.insert_buffer_size", '1GB')
        for i in range(20):
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("cache", "insert_buffer_size", '2GB')
+                relpy = connect.set_config("cache.insert_buffer_size", '2GB')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -265,7 +257,7 @@ class TestCacheConfig:
        mem_total = self.get_memory_total(connect)
        logging.getLogger().info(mem_total)
        with pytest.raises(Exception) as e:
-            relpy = connect.set_config("cache", "cache_size", str(int(mem_total + 1)+''))
+            relpy = connect.set_config("cache.cache_size", str(int(mem_total + 1)+''))
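Note: the change running through every hunk above is mechanical: Milvus 0.11 collapses the two-argument get_config/set_config calls into a single dotted path. A minimal sketch of the new call shape, assuming a connected client fixture `connect` like the one these tests receive:

    import pytest

    def check_cache_size_roundtrip(connect):
        # New dotted-path form used throughout this commit.
        connect.set_config("cache.cache_size", '4GB')
        assert connect.get_config("cache.cache_size") == '4GB'
        # A malformed parent key is expected to raise, as the
        # invalid_configs loops above assert.
        with pytest.raises(Exception):
            connect.get_config("cacheconfig.cache_size")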
@@ -292,7 +284,7 @@ class TestGPUConfig:
        invalid_configs = ["Engine_config", "engine config"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "gpu_search_threshold")
+                config_value = connect.get_config(config+".gpu_search_threshold")

    @pytest.mark.level(2)
    def test_get_gpu_search_threshold_invalid_child_key(self, connect, collection):

@@ -306,7 +298,7 @@ class TestGPUConfig:
        invalid_configs = ["Gpu_search threshold", "gpusearchthreshold"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("gpu", config)
+                config_value = connect.get_config("gpu."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_gpu_search_threshold_valid(self, connect, collection):

@@ -317,7 +309,7 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        config_value = connect.get_config("gpu", "gpu_search_threshold")
+        config_value = connect.get_config("gpu.gpu_search_threshold")
        assert config_value

    """

@@ -336,7 +328,7 @@ class TestGPUConfig:
        invalid_configs = ["abc", 1]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("gpu", config, 1000)
+                relpy = connect.set_config("gpu."+config, 1000)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -351,7 +343,7 @@ class TestGPUConfig:
        invalid_configs = ["Engine_config", "engine config"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config(config, "gpu_search_threshold", 1000)
+                relpy = connect.set_config(config+".gpu_search_threshold", 1000)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -363,8 +355,8 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        relpy = connect.set_config("gpu", "gpu_search_threshold", 2000)
-        config_value = connect.get_config("gpu", "gpu_search_threshold")
+        relpy = connect.set_config("gpu.gpu_search_threshold", 2000)
+        config_value = connect.get_config("gpu.gpu_search_threshold")
        assert config_value == '2000'

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")

@@ -377,10 +369,10 @@ class TestGPUConfig:
        '''
        for i in [-1, "1000\n", "1000\t", "1000.0", 1000.35]:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("gpu", "use_blas_threshold", i)
+                relpy = connect.set_config("gpu.use_blas_threshold", i)
            if str(connect._cmd("mode")) == "GPU":
                with pytest.raises(Exception) as e:
-                    relpy = connect.set_config("gpu", "gpu_search_threshold", i)
+                    relpy = connect.set_config("gpu.gpu_search_threshold", i)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -388,8 +380,8 @@ class TestGPUConfig:
        '''
        reset configs so the tests are stable
        '''
-        relpy = connect.set_config("gpu", "cache_size", 1)
-        config_value = connect.get_config("gpu", "cache_size")
+        relpy = connect.set_config("gpu.cache_size", 1)
+        config_value = connect.get_config("gpu.cache_size")
        assert config_value == '1'

        # follows can not be changed

@@ -416,7 +408,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "enable")
+                config_value = connect.get_config(config+".enable")

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_gpu_enable_invalid_child_key(self, connect, collection):

@@ -430,7 +422,7 @@ class TestGPUConfig:
        invalid_configs = ["Enab_le", "enab_le ", "disable", "true"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("gpu", config)
+                config_value = connect.get_config("gpu."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_gpu_enable_valid(self, connect, collection):

@@ -441,7 +433,7 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        config_value = connect.get_config("gpu", "enable")
+        config_value = connect.get_config("gpu.enable")
        assert config_value == "true" or config_value == "false"

    @pytest.mark.level(2)

@@ -457,7 +449,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "cache_size")
+                config_value = connect.get_config(config+".cache_size")

    @pytest.mark.level(2)
    def test_get_cache_size_invalid_child_key(self, connect, collection):

@@ -471,7 +463,7 @@ class TestGPUConfig:
        invalid_configs = ["Cache_capacity", "cachecapacity"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("gpu", config)
+                config_value = connect.get_config("gpu."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_cache_size_valid(self, connect, collection):

@@ -482,7 +474,7 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        config_value = connect.get_config("gpu", "cache_size")
+        config_value = connect.get_config("gpu.cache_size")

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_search_devices_invalid_parent_key(self, connect, collection):

@@ -497,7 +489,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "search_devices")
+                config_value = connect.get_config(config+".search_devices")

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_search_devices_invalid_child_key(self, connect, collection):

@@ -511,7 +503,7 @@ class TestGPUConfig:
        invalid_configs = ["Search_resources"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("gpu", config)
+                config_value = connect.get_config("gpu."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_search_devices_valid(self, connect, collection):

@@ -522,7 +514,7 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        config_value = connect.get_config("gpu", "search_devices")
+        config_value = connect.get_config("gpu.search_devices")
        logging.getLogger().info(config_value)

    @pytest.mark.level(2)

@@ -538,7 +530,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config(config, "build_index_devices")
+                config_value = connect.get_config(config+".build_index_devices")

    @pytest.mark.level(2)
    def test_get_build_index_devices_invalid_child_key(self, connect, collection):

@@ -552,7 +544,7 @@ class TestGPUConfig:
        invalid_configs = ["Build_index_resources"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("gpu", config)
+                config_value = connect.get_config("gpu."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_build_index_devices_valid(self, connect, collection):

@@ -563,7 +555,7 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        config_value = connect.get_config("gpu", "build_index_devices")
+        config_value = connect.get_config("gpu.build_index_devices")
        logging.getLogger().info(config_value)
        assert config_value

@@ -586,7 +578,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config(config, "enable", "true")
+                relpy = connect.set_config(config+".enable", "true")

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -602,7 +594,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("gpu", config, "true")
+                relpy = connect.set_config("gpu."+config, "true")

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -616,7 +608,7 @@ class TestGPUConfig:
            pytest.skip("Only support GPU mode")
        for i in [-1, -2, 100]:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("gpu", "enable", i)
+                relpy = connect.set_config("gpu.enable", i)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -631,7 +623,7 @@ class TestGPUConfig:
        valid_configs = ["off", "False", "0", "nO", "on", "True", 1, "yES"]
        for config in valid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("gpu", "enable", config)
+                relpy = connect.set_config("gpu.enable", config)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -647,7 +639,7 @@ class TestGPUConfig:
                           "gpu_resource"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config(config, "cache_size", 2)
+                relpy = connect.set_config(config+".cache_size", 2)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -659,7 +651,7 @@ class TestGPUConfig:
        '''
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")
-        relpy = connect.set_config("gpu", "cache_size", 2)
+        relpy = connect.set_config("gpu.cache_size", 2)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)
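Note: every GPU-only test above opens with the same guard: ask the server for its mode, and skip on CPU builds. The guard extracted as a sketch (the `connect._cmd("mode")` call and the skip message are taken verbatim from the tests):

    import pytest

    def skip_unless_gpu(connect):
        # connect._cmd("mode") reports "CPU" or "GPU"; GPU-only config
        # keys are meaningless on CPU builds, so skip early.
        if str(connect._cmd("mode")) == "CPU":
            pytest.skip("Only support GPU mode")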
@@ -786,7 +778,7 @@ class TestNetworkConfig:
        invalid_configs = ["Address", "addresses", "address "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("network", config)
+                config_value = connect.get_config("network."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_address_valid(self, connect, collection):

@@ -795,7 +787,7 @@ class TestNetworkConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("network", "bind.address")
+        config_value = connect.get_config("network.bind.address")

    @pytest.mark.level(2)
    def test_get_port_invalid_child_key(self, connect, collection):

@@ -807,7 +799,7 @@ class TestNetworkConfig:
        invalid_configs = ["Port", "PORT", "port "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("network", config)
+                config_value = connect.get_config("network."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_port_valid(self, connect, collection):

@@ -816,7 +808,7 @@ class TestNetworkConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("network", "http.port")
+        config_value = connect.get_config("network.http.port")
        assert config_value

    @pytest.mark.level(2)

@@ -829,7 +821,7 @@ class TestNetworkConfig:
        invalid_configs = ["webport", "Web_port", "http port "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("network", config)
+                config_value = connect.get_config("network."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_http_port_valid(self, connect, collection):

@@ -838,7 +830,7 @@ class TestNetworkConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("network", "http.port")
+        config_value = connect.get_config("network.http.port")
        assert config_value

    """

@@ -863,7 +855,7 @@ class TestNetworkConfig:
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
-            relpy = connect.set_config("network", "child_key", 19530)
+            relpy = connect.set_config("network.child_key", 19530)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -873,7 +865,7 @@ class TestNetworkConfig:
        method: call set_config correctly
        expected: status ok, set successfully
        '''
-        relpy = connect.set_config("network", "bind.address", '0.0.0.0')
+        relpy = connect.set_config("network.bind.address", '0.0.0.0')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_port_valid(self, connect, collection):

@@ -883,7 +875,7 @@ class TestNetworkConfig:
        expected: status ok, set successfully
        '''
        for valid_port in [1025, 65534, 12345, "19530"]:
-            relpy = connect.set_config("network", "http.port", valid_port)
+            relpy = connect.set_config("network.http.port", valid_port)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_port_invalid(self, connect, collection):

@@ -895,7 +887,7 @@ class TestNetworkConfig:
        for invalid_port in [1024, 65535, "0", "True", "100000"]:
            logging.getLogger().info(invalid_port)
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("network", "http.port", invalid_port)
+                relpy = connect.set_config("network.http.port", invalid_port)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_http_port_valid(self, connect, collection):

@@ -905,7 +897,7 @@ class TestNetworkConfig:
        expected: status ok, set successfully
        '''
        for valid_http_port in [1025, 65534, "12345", 19121]:
-            relpy = connect.set_config("network", "http.port", valid_http_port)
+            relpy = connect.set_config("network.http.port", valid_http_port)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_http_port_invalid(self, connect, collection):

@@ -916,7 +908,7 @@ class TestNetworkConfig:
        '''
        for invalid_http_port in [1024, 65535, "0", "True", "1000000"]:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("network", "http.port", invalid_http_port)
+                relpy = connect.set_config("network.http.port", invalid_http_port)
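Note: the valid/invalid port lists above pin down the accepted range: 1025 through 65534, given as an int or a numeric string, while 1024, 65535, "0", "True", and "100000" must all raise. A small illustration of that contract (a restatement of what the tests check, not the server's actual validator):

    def is_acceptable_port(value):
        # Numeric strings such as "19530" are accepted; non-numeric
        # strings such as "True" are not.
        try:
            port = int(str(value))
        except ValueError:
            return False
        # Both boundary values, 1024 and 65535, are rejected above.
        return 1024 < port < 65535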
class TestGeneralConfig:

@@ -940,7 +932,7 @@ class TestGeneralConfig:
        invalid_configs = ["backend_Url", "backend-url", "meta uri "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("general", config)
+                config_value = connect.get_config("general."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_meta_uri_valid(self, connect, collection):

@@ -949,7 +941,7 @@ class TestGeneralConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("general", "meta_uri")
+        config_value = connect.get_config("general.meta_uri")
        assert config_value

    @pytest.mark.level(2)

@@ -962,7 +954,7 @@ class TestGeneralConfig:
        invalid_configs = ["time", "time_zone "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("general", config)
+                config_value = connect.get_config("general."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_timezone_valid(self, connect, collection):

@@ -971,7 +963,7 @@ class TestGeneralConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("general", "timezone")
+        config_value = connect.get_config("general.timezone")
        assert "UTC" in config_value

    """

@@ -989,7 +981,7 @@ class TestGeneralConfig:
        for invalid_timezone in ["utc++8", "UTC++8"]:
            logging.getLogger().info(invalid_timezone)
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("general", "timezone", invalid_timezone)
+                relpy = connect.set_config("general.timezone", invalid_timezone)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1000,7 +992,7 @@ class TestGeneralConfig:
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
-            relpy = connect.set_config("general", "child_key", 1)
+            relpy = connect.set_config("general.child_key", 1)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1010,7 +1002,7 @@ class TestGeneralConfig:
        method: call set_config correctly
        expected: status ok, set successfully
        '''
-        relpy = connect.set_config("general", "meta_uri", 'sqlite://:@:/')
+        relpy = connect.set_config("general.meta_uri", 'sqlite://:@:/')
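Note: for the general section the tests pin exactly two behaviors: `general.timezone` must contain "UTC", and clearly malformed zones such as "utc++8" must be rejected by set_config. Restated compactly, assuming the same `connect` fixture:

    import pytest

    def check_timezone_behaviour(connect):
        # A running server reports a UTC-offset timezone.
        assert "UTC" in connect.get_config("general.timezone")
        # Malformed offsets are refused.
        with pytest.raises(Exception):
            connect.set_config("general.timezone", "utc++8")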
class TestStorageConfig:

@@ -1034,7 +1026,7 @@ class TestStorageConfig:
        invalid_configs = ["Primary_path", "primarypath", "pa_th "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("storage", config)
+                config_value = connect.get_config("storage."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_path_valid(self, connect, collection):

@@ -1043,7 +1035,7 @@ class TestStorageConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("storage", "path")
+        config_value = connect.get_config("storage.path")
        assert config_value

    @pytest.mark.level(2)

@@ -1056,7 +1048,7 @@ class TestStorageConfig:
        invalid_configs = ["autoFlushInterval", "auto_flush", "auto_flush interval "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("storage", config)
+                config_value = connect.get_config("storage."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_auto_flush_interval_valid(self, connect, collection):

@@ -1065,7 +1057,7 @@ class TestStorageConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("storage", "auto_flush_interval")
+        config_value = connect.get_config("storage.auto_flush_interval")

    """
    ******************************************************************

@@ -1081,7 +1073,7 @@ class TestStorageConfig:
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
-            relpy = connect.set_config("storage", "child_key", "")
+            relpy = connect.set_config("storage.child_key", "")

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1091,7 +1083,7 @@ class TestStorageConfig:
        method: call set_config correctly
        expected: status ok, set successfully
        '''
-        relpy = connect.set_config("storage", "path", '/var/lib/milvus')
+        relpy = connect.set_config("storage.path", '/var/lib/milvus')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_auto_flush_interval_valid(self, connect, collection):

@@ -1102,8 +1094,8 @@ class TestStorageConfig:
        '''
        for valid_auto_flush_interval in [2, 1]:
            logging.getLogger().info(valid_auto_flush_interval)
-            relpy = connect.set_config("storage", "auto_flush_interval", valid_auto_flush_interval)
-            config_value = connect.get_config("storage", "auto_flush_interval")
+            relpy = connect.set_config("storage.auto_flush_interval", valid_auto_flush_interval)
+            config_value = connect.get_config("storage.auto_flush_interval")
            assert config_value == str(valid_auto_flush_interval)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")

@@ -1115,7 +1107,7 @@ class TestStorageConfig:
        '''
        for invalid_auto_flush_interval in [-1, "1.5", "invalid", "1+2"]:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("storage", "auto_flush_interval", invalid_auto_flush_interval)
+                relpy = connect.set_config("storage.auto_flush_interval", invalid_auto_flush_interval)
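Note: the assertion shape in test_set_auto_flush_interval_valid is easy to miss: the interval is set as an int but read back as a string, so the comparison goes through str(). The round-trip in isolation:

    def check_auto_flush_roundtrip(connect, interval=2):
        # get_config returns strings, hence str(interval) on the right.
        connect.set_config("storage.auto_flush_interval", interval)
        assert connect.get_config("storage.auto_flush_interval") == str(interval)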
class TestMetricConfig:

@@ -1139,7 +1131,7 @@ class TestMetricConfig:
        invalid_configs = ["enablemonitor", "Enable_monitor", "en able "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("metric", config)
+                config_value = connect.get_config("metric."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_enable_valid(self, connect, collection):

@@ -1148,7 +1140,7 @@ class TestMetricConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("metric", "enable")
+        config_value = connect.get_config("metric.enable")
        assert config_value

    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1161,7 +1153,7 @@ class TestMetricConfig:
        invalid_configs = ["Add ress", "addresses", "add ress "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("metric", config)
+                config_value = connect.get_config("metric."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_address_valid(self, connect, collection):

@@ -1170,7 +1162,7 @@ class TestMetricConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("metric", "address")
+        config_value = connect.get_config("metric.address")
        assert config_value

    @pytest.mark.level(2)

@@ -1183,7 +1175,7 @@ class TestMetricConfig:
        invalid_configs = ["Po_rt", "PO_RT", "po_rt "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("metric", config)
+                config_value = connect.get_config("metric."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_port_valid(self, connect, collection):

@@ -1192,7 +1184,7 @@ class TestMetricConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("metric", "port")
+        config_value = connect.get_config("metric.port")
        assert config_value

    """

@@ -1209,7 +1201,7 @@ class TestMetricConfig:
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
-            relpy = connect.set_config("metric", "child_key", 19530)
+            relpy = connect.set_config("metric.child_key", 19530)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_enable_valid(self, connect, collection):

@@ -1219,7 +1211,7 @@ class TestMetricConfig:
        expected: status ok, set successfully
        '''
        for valid_enable in ["false", "true"]:
-            relpy = connect.set_config("metric", "enable", valid_enable)
+            relpy = connect.set_config("metric.enable", valid_enable)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1229,7 +1221,7 @@ class TestMetricConfig:
        method: call set_config correctly
        expected: status ok, set successfully
        '''
-        relpy = connect.set_config("metric", "address", '127.0.0.1')
+        relpy = connect.set_config("metric.address", '127.0.0.1')

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_port_valid(self, connect, collection):

@@ -1239,7 +1231,7 @@ class TestMetricConfig:
        expected: status ok, set successfully
        '''
        for valid_port in [1025, 65534, "19530", "9091"]:
-            relpy = connect.set_config("metric", "port", valid_port)
+            relpy = connect.set_config("metric.port", valid_port)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_port_invalid(self, connect, collection):

@@ -1250,7 +1242,7 @@ class TestMetricConfig:
        '''
        for invalid_port in [1024, 65535, "0", "True", "100000"]:
            with pytest.raises(Exception) as e:
-                relpy = connect.set_config("metric", "port", invalid_port)
+                relpy = connect.set_config("metric.port", invalid_port)
class TestWALConfig:

@@ -1274,7 +1266,7 @@ class TestWALConfig:
        invalid_configs = ["enabled", "Enab_le", "enable_"]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("wal", config)
+                config_value = connect.get_config("wal."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_enable_valid(self, connect, collection):

@@ -1283,7 +1275,7 @@ class TestWALConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("wal", "enable")
+        config_value = connect.get_config("wal.enable")
        assert config_value

    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1296,7 +1288,7 @@ class TestWALConfig:
        invalid_configs = ["recovery-error-ignore", "Recovery error_ignore", "recoveryxerror_ignore "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("wal", config)
+                config_value = connect.get_config("wal."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_recovery_error_ignore_valid(self, connect, collection):

@@ -1305,7 +1297,7 @@ class TestWALConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("wal", "recovery_error_ignore")
+        config_value = connect.get_config("wal.recovery_error_ignore")
        assert config_value

    @pytest.mark.level(2)

@@ -1318,7 +1310,7 @@ class TestWALConfig:
        invalid_configs = ["buffersize", "Buffer size", "buffer size "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("wal", config)
+                config_value = connect.get_config("wal."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_buffer_size_valid(self, connect, collection):

@@ -1327,7 +1319,7 @@ class TestWALConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("wal", "buffer_size")
+        config_value = connect.get_config("wal.buffer_size")
        assert config_value

    @pytest.mark.level(2)

@@ -1340,7 +1332,7 @@ class TestWALConfig:
        invalid_configs = ["wal", "Wal_path", "wal_path "]
        for config in invalid_configs:
            with pytest.raises(Exception) as e:
-                config_value = connect.get_config("wal", config)
+                config_value = connect.get_config("wal."+config)

    @pytest.mark.timeout(CONFIG_TIMEOUT)
    def test_get_wal_path_valid(self, connect, collection):

@@ -1349,7 +1341,7 @@ class TestWALConfig:
        method: call get_config correctly
        expected: status ok
        '''
-        config_value = connect.get_config("wal", "path")
+        config_value = connect.get_config("wal.path")
        assert config_value

    """

@@ -1366,7 +1358,7 @@ class TestWALConfig:
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
-            relpy = connect.set_config("wal", "child_key", 256)
+            relpy = connect.set_config("wal.child_key", 256)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_enable_valid(self, connect, collection):

@@ -1376,7 +1368,7 @@ class TestWALConfig:
        expected: status ok, set successfully
        '''
        for valid_enable in ["false", "true"]:
-            relpy = connect.set_config("wal", "enable", valid_enable)
+            relpy = connect.set_config("wal.enable", valid_enable)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_recovery_error_ignore_valid(self, connect, collection):

@@ -1386,7 +1378,7 @@ class TestWALConfig:
        expected: status ok, set successfully
        '''
        for valid_recovery_error_ignore in ["false", "true"]:
-            relpy = connect.set_config("wal", "recovery_error_ignore", valid_recovery_error_ignore)
+            relpy = connect.set_config("wal.recovery_error_ignore", valid_recovery_error_ignore)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    def test_set_buffer_size_valid_A(self, connect, collection):

@@ -1396,7 +1388,7 @@ class TestWALConfig:
        expected: status ok, set successfully
        '''
        for valid_buffer_size in ["64MB", "128MB", "4096MB", "1000MB", "256MB"]:
-            relpy = connect.set_config("wal", "buffer_size", valid_buffer_size)
+            relpy = connect.set_config("wal.buffer_size", valid_buffer_size)

    @pytest.mark.skip(reason="overwrite config file is not supported in ci yet.")
    @pytest.mark.timeout(CONFIG_TIMEOUT)

@@ -1406,5 +1398,5 @@ class TestWALConfig:
        method: call set_config correctly
        expected: status ok, set successfully
        '''
-        relpy = connect.set_config("wal", "path", "/var/lib/milvus/wal")
+        relpy = connect.set_config("wal.path", "/var/lib/milvus/wal")
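Note: several tests above call self.reset_configs(connect), and the first hunk's docstring states its purpose ("reset configs so the tests are stable"). A plausible shape for such a helper under the new dotted-path API; this is an illustrative sketch, not the repository's exact implementation:

    def reset_configs(connect):
        # Put the cache size back to the documented '4GB' baseline so
        # later tests start from a known state.
        connect.set_config("cache.cache_size", '4GB')
        assert connect.get_config("cache.cache_size") == '4GB'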
@@ -5,32 +5,18 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
from constants import *

-dim = 128
-segment_row_count = 5000
-index_file_size = 10
collection_id = "test_flush"
DELETE_TIMEOUT = 60
-nprobe = 1
-tag = "1970_01_01"
-top_k = 1
-nb = 6000
-tag = "partition_tag"
-field_name = "float_vector"
-entity = gen_entities(1)
-entities = gen_entities(nb)
-raw_vector, binary_entity = gen_binary_entities(1)
-raw_vectors, binary_entities = gen_binary_entities(nb)
-default_fields = gen_default_fields()
default_single_query = {
    "bool": {
        "must": [
-            {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim), "metric_type":"L2","params": {"nprobe": 10}}}}
+            {"vector": {default_float_vec_field_name: {"topk": 10, "query": gen_vectors(1, default_dim),
+                                                       "metric_type": "L2", "params": {"nprobe": 10}}}}
        ]
    }
}


class TestFlushBase:
    """
    ******************************************************************

@@ -77,8 +63,8 @@ class TestFlushBase:
        method: flush collection with no vectors
        expected: no error raised
        '''
-        ids = connect.insert(collection, entities)
-        assert len(ids) == nb
+        ids = connect.insert(collection, default_entities)
+        assert len(ids) == default_nb
        status = connect.delete_entity_by_id(collection, ids)
        assert status.OK()
        res = connect.count_entities(collection)

@@ -91,36 +77,33 @@ class TestFlushBase:
        method: add entities into partition in collection, flush several times
        expected: the length of ids and the collection row count
        '''
-        # vector = gen_vector(nb, dim)
-        connect.create_partition(id_collection, tag)
-        # vectors = gen_vectors(nb, dim)
-        ids = [i for i in range(nb)]
-        ids = connect.insert(id_collection, entities, ids)
+        connect.create_partition(id_collection, default_tag)
+        ids = [i for i in range(default_nb)]
+        ids = connect.insert(id_collection, default_entities, ids)
        connect.flush([id_collection])
        res_count = connect.count_entities(id_collection)
-        assert res_count == nb
-        ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
-        assert len(ids) == nb
+        assert res_count == default_nb
+        ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
+        assert len(ids) == default_nb
        connect.flush([id_collection])
        res_count = connect.count_entities(id_collection)
-        assert res_count == nb * 2
+        assert res_count == default_nb * 2

    def test_add_partitions_flush(self, connect, id_collection):
        '''
        method: add entities into partitions in collection, flush one
        expected: the length of ids and the collection row count
        '''
-        # vectors = gen_vectors(nb, dim)
        tag_new = gen_unique_str()
-        connect.create_partition(id_collection, tag)
+        connect.create_partition(id_collection, default_tag)
        connect.create_partition(id_collection, tag_new)
-        ids = [i for i in range(nb)]
-        ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
+        ids = [i for i in range(default_nb)]
+        ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
        connect.flush([id_collection])
-        ids = connect.insert(id_collection, entities, ids, partition_tag=tag_new)
+        ids = connect.insert(id_collection, default_entities, ids, partition_tag=tag_new)
        connect.flush([id_collection])
        res = connect.count_entities(id_collection)
-        assert res == 2 * nb
+        assert res == 2 * default_nb

    def test_add_collections_flush(self, connect, id_collection):
        '''

@@ -130,18 +113,17 @@ class TestFlushBase:
        collection_new = gen_unique_str()
        default_fields = gen_default_fields(False)
        connect.create_collection(collection_new, default_fields)
-        connect.create_partition(id_collection, tag)
-        connect.create_partition(collection_new, tag)
-        # vectors = gen_vectors(nb, dim)
-        ids = [i for i in range(nb)]
-        ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
-        ids = connect.insert(collection_new, entities, ids, partition_tag=tag)
+        connect.create_partition(id_collection, default_tag)
+        connect.create_partition(collection_new, default_tag)
+        ids = [i for i in range(default_nb)]
+        ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
+        ids = connect.insert(collection_new, default_entities, ids, partition_tag=default_tag)
        connect.flush([id_collection])
        connect.flush([collection_new])
        res = connect.count_entities(id_collection)
-        assert res == nb
+        assert res == default_nb
        res = connect.count_entities(collection_new)
-        assert res == nb
+        assert res == default_nb

    def test_add_collections_fields_flush(self, connect, id_collection, get_filter_field, get_vector_field):
        '''

@@ -154,22 +136,21 @@ class TestFlushBase:
        collection_new = gen_unique_str("test_flush")
        fields = {
            "fields": [filter_field, vector_field],
-            "segment_row_limit": segment_row_count,
+            "segment_row_limit": default_segment_row_limit,
            "auto_id": False
        }
        connect.create_collection(collection_new, fields)
-        connect.create_partition(id_collection, tag)
-        connect.create_partition(collection_new, tag)
-        # vectors = gen_vectors(nb, dim)
-        entities_new = gen_entities_by_fields(fields["fields"], nb_new, dim)
-        ids = [i for i in range(nb)]
+        connect.create_partition(id_collection, default_tag)
+        connect.create_partition(collection_new, default_tag)
+        entities_new = gen_entities_by_fields(fields["fields"], nb_new, default_dim)
+        ids = [i for i in range(default_nb)]
        ids_new = [i for i in range(nb_new)]
-        ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
-        ids = connect.insert(collection_new, entities_new, ids_new, partition_tag=tag)
+        ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
+        ids = connect.insert(collection_new, entities_new, ids_new, partition_tag=default_tag)
        connect.flush([id_collection])
        connect.flush([collection_new])
        res = connect.count_entities(id_collection)
-        assert res == nb
+        assert res == default_nb
        res = connect.count_entities(collection_new)
        assert res == nb_new

@@ -178,8 +159,7 @@ class TestFlushBase:
        method: add entities, flush several times
        expected: no error raised
        '''
-        # vectors = gen_vectors(nb, dim)
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
        for i in range(10):
            connect.flush([collection])
        res = connect.count_entities(collection)

@@ -194,15 +174,14 @@ class TestFlushBase:
        method: add entities
        expected: no error raised
        '''
-        # vectors = gen_vectors(nb, dim)
-        ids = [i for i in range(nb)]
-        ids = connect.insert(id_collection, entities, ids)
+        ids = [i for i in range(default_nb)]
+        ids = connect.insert(id_collection, default_entities, ids)
        timeout = 20
        start_time = time.time()
        while (time.time() - start_time < timeout):
            time.sleep(1)
            res = connect.count_entities(id_collection)
-            if res == nb:
+            if res == default_nb:
                break
        if time.time() - start_time > timeout:
            assert False

@@ -222,23 +201,21 @@ class TestFlushBase:
        method: add entities, with same ids, count(same ids) < 15, > 15
        expected: the length of ids and the collection row count
        '''
-        # vectors = gen_vectors(nb, dim)
-        ids = [i for i in range(nb)]
+        ids = [i for i in range(default_nb)]
        for i, item in enumerate(ids):
            if item <= same_ids:
                ids[i] = 0
-        ids = connect.insert(id_collection, entities, ids)
+        ids = connect.insert(id_collection, default_entities, ids)
        connect.flush([id_collection])
        res = connect.count_entities(id_collection)
-        assert res == nb
+        assert res == default_nb

    def test_delete_flush_multiable_times(self, connect, collection):
        '''
        method: delete entities, flush several times
        expected: no error raised
        '''
-        # vectors = gen_vectors(nb, dim)
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
        status = connect.delete_entity_by_id(collection, [ids[-1]])
        assert status.OK()
        for i in range(10):

@@ -257,7 +234,7 @@ class TestFlushBase:
        '''
        ids = []
        for i in range(5):
-            tmp_ids = connect.insert(collection, entities)
+            tmp_ids = connect.insert(collection, default_entities)
            connect.flush([collection])
            ids.extend(tmp_ids)
        disable_flush(connect)

@@ -302,15 +279,19 @@ class TestFlushAsync:
        status = future.result()

    def test_flush_async_long(self, connect, collection):
-        # vectors = gen_vectors(nb, dim)
-        ids = connect.insert(collection, entities)
+        ids = connect.insert(collection, default_entities)
        future = connect.flush([collection], _async=True)
        status = future.result()

+    def test_flush_async_long_drop_collection(self, connect, collection):
+        for i in range(5):
+            ids = connect.insert(collection, default_entities)
+        future = connect.flush([collection], _async=True)
+        logging.getLogger().info("DROP")
+        connect.drop_collection(collection)
+
    def test_flush_async(self, connect, collection):
-        nb = 100000
-        vectors = gen_vectors(nb, dim)
-        connect.insert(collection, entities)
+        connect.insert(collection, default_entities)
        logging.getLogger().info("before")
        future = connect.flush([collection], _async=True, _callback=self.check_status)
        logging.getLogger().info("after")
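Note: the header hunks of both files above are the heart of the port: per-file constants (nb, dim, tag, pre-built entities) move into a shared constants module and gain a default_ prefix. A sketch of what that module plausibly provides; the names appear in this diff, the concrete values are assumptions for illustration:

    # constants.py (illustrative sketch; values are assumptions)
    default_nb = 6000
    default_dim = 128
    default_top_k = 10
    default_tag = "1970_01_01"
    default_segment_row_limit = 5000
    default_float_vec_field_name = "float_vector"
    default_binary_vec_field_name = "binary_vector"

default_entities and default_binary_entities would then be pre-generated once, e.g. via gen_entities(default_nb) and gen_binary_entities(default_nb), mirroring the module-level lines this commit deletes.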
@ -7,26 +7,14 @@ import numpy
|
|||
import pytest
|
||||
import sklearn.preprocessing
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
nb = 6000
|
||||
dim = 128
|
||||
index_file_size = 10
|
||||
uid = "test_index"
|
||||
BUILD_TIMEOUT = 300
|
||||
nprobe = 1
|
||||
top_k = 5
|
||||
tag = "1970_01_01"
|
||||
NLIST = 4046
|
||||
INVALID_NLIST = 100000000
|
||||
field_name = "float_vector"
|
||||
binary_field_name = "binary_vector"
|
||||
collection_id = "index"
|
||||
default_index_type = "FLAT"
|
||||
entity = gen_entities(1)
|
||||
entities = gen_entities(nb)
|
||||
raw_vector, binary_entity = gen_binary_entities(1)
|
||||
raw_vectors, binary_entities = gen_binary_entities(nb)
|
||||
query, query_vecs = gen_query_vectors(field_name, entities, top_k, 1)
|
||||
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 1024}, "metric_type": "L2"}
|
||||
field_name = default_float_vec_field_name
|
||||
binary_field_name = default_binary_vec_field_name
|
||||
query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1)
|
||||
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
|
||||
|
||||
|
||||
class TestIndexBase:
|
||||
|
@ -46,7 +34,7 @@ class TestIndexBase:
|
|||
params=[
|
||||
1,
|
||||
10,
|
||||
1500
|
||||
1111
|
||||
],
|
||||
)
|
||||
def get_nq(self, request):
|
||||
|
@ -65,9 +53,32 @@ class TestIndexBase:
|
|||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
def test_create_index_on_field_not_existed(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index on field not existed
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_field_name = gen_unique_str()
|
||||
ids = connect.insert(collection, default_entities)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection, tmp_field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_on_field(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create collection and add entities in it, create index on other field
|
||||
expected: error raised
|
||||
'''
|
||||
tmp_field_name = "int64"
|
||||
ids = connect.insert(collection, default_entities)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection, tmp_field_name, get_simple_index)
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_no_vectors(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
|
@ -84,8 +95,8 @@ class TestIndexBase:
|
|||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
|
@ -96,8 +107,8 @@ class TestIndexBase:
|
|||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush()
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
|
@ -117,13 +128,13 @@ class TestIndexBase:
|
|||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
logging.getLogger().info(connect.get_collection_stats(collection))
|
||||
nq = get_nq
|
||||
index_type = get_simple_index["index_type"]
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
|
||||
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param)
|
||||
res = connect.search(collection, query)
|
||||
assert len(res) == nq
|
||||
|
||||
|
@ -135,7 +146,7 @@ class TestIndexBase:
|
|||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
|
||||
def build(connect):
|
||||
connect.create_index(collection, field_name, default_index)
|
||||
|
@ -158,7 +169,7 @@ class TestIndexBase:
|
|||
, make sure the collection name not in index
|
||||
expected: create index failed
|
||||
'''
|
||||
collection_name = gen_unique_str(collection_id)
|
||||
collection_name = gen_unique_str(uid)
|
||||
with pytest.raises(Exception) as e:
|
||||
connect.create_index(collection_name, field_name, default_index)
|
||||
|
||||
|
@ -171,10 +182,10 @@ class TestIndexBase:
|
|||
expected: create index ok, and count correct
|
||||
'''
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
connect.flush([collection])
|
||||
count = connect.count_entities(collection)
|
||||
assert count == nb
|
||||
assert count == default_nb
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
|
@ -196,13 +207,13 @@ class TestIndexBase:
|
|||
method: create another index with different index_params after index have been built
|
||||
expected: return code 0, and describe index result equals with the second index params
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}]
|
||||
for index in indexs:
|
||||
connect.create_index(collection, field_name, index)
|
||||
stats = connect.get_collection_stats(collection)
|
||||
# assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
|
||||
assert stats["row_count"] == nb
|
||||
assert stats["row_count"] == default_nb
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_ip(self, connect, collection, get_simple_index):
|
||||
|
@ -211,7 +222,7 @@ class TestIndexBase:
|
|||
method: create collection and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
||||
|
@ -232,8 +243,8 @@ class TestIndexBase:
|
|||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush([collection])
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
@ -245,8 +256,8 @@ class TestIndexBase:
|
|||
method: create collection, create partition, and add entities in it, create index
|
||||
expected: return search success
|
||||
'''
|
||||
connect.create_partition(collection, tag)
|
||||
ids = connect.insert(collection, entities, partition_tag=tag)
|
||||
connect.create_partition(collection, default_tag)
|
||||
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
|
||||
connect.flush()
|
||||
get_simple_index["metric_type"] = "IP"
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
|
@ -259,14 +270,14 @@ class TestIndexBase:
|
|||
expected: return search success
|
||||
'''
|
||||
metric_type = "IP"
|
||||
ids = connect.insert(collection, entities)
|
||||
ids = connect.insert(collection, default_entities)
|
||||
get_simple_index["metric_type"] = metric_type
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
logging.getLogger().info(connect.get_collection_stats(collection))
|
||||
nq = get_nq
|
||||
index_type = get_simple_index["index_type"]
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, search_params=search_param)
|
||||
query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param)
|
||||
res = connect.search(collection, query)
|
||||
assert len(res) == nq

@@ -278,7 +289,7 @@ class TestIndexBase:
        method: create collection and add entities in it, create index
        expected: return search success
        '''
-       ids = connect.insert(collection, entities)
+       ids = connect.insert(collection, default_entities)

        def build(connect):
            default_index["metric_type"] = "IP"

@@ -302,7 +313,7 @@ class TestIndexBase:
        , make sure the collection name not in index
        expected: return code not equals to 0, create index failed
        '''
-       collection_name = gen_unique_str(collection_id)
+       collection_name = gen_unique_str(uid)
        default_index["metric_type"] = "IP"
        with pytest.raises(Exception) as e:
            connect.create_index(collection_name, field_name, default_index)

@@ -316,10 +327,10 @@ class TestIndexBase:
        '''
        default_index["metric_type"] = "IP"
        connect.create_index(collection, field_name, get_simple_index)
-       ids = connect.insert(collection, entities)
+       ids = connect.insert(collection, default_entities)
        connect.flush([collection])
        count = connect.count_entities(collection)
-       assert count == nb
+       assert count == default_nb

    @pytest.mark.level(2)
    @pytest.mark.timeout(BUILD_TIMEOUT)

@@ -342,13 +353,13 @@ class TestIndexBase:
        method: create another index with different index_params after index have been built
        expected: return code 0, and describe index result equals with the second index params
        '''
-       ids = connect.insert(collection, entities)
+       ids = connect.insert(collection, default_entities)
        indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}]
        for index in indexs:
            connect.create_index(collection, field_name, index)
            stats = connect.get_collection_stats(collection)
            # assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
-           assert stats["row_count"] == nb
+           assert stats["row_count"] == default_nb

    """
    ******************************************************************

@@ -402,7 +413,7 @@ class TestIndexBase:
        , make sure the collection name not in index, and then drop it
        expected: return code not equals to 0, drop index failed
        '''
-       collection_name = gen_unique_str(collection_id)
+       collection_name = gen_unique_str(uid)
        with pytest.raises(Exception) as e:
            connect.drop_index(collection_name, field_name)

@@ -526,7 +537,7 @@ class TestIndexBinary:
        params=[
            1,
            10,
-           1500
+           1111
        ],
    )
    def get_nq(self, request):

@@ -545,7 +556,7 @@ class TestIndexBinary:
        method: create collection and add entities in it, create index
        expected: return search success
        '''
-       ids = connect.insert(binary_collection, binary_entities)
+       ids = connect.insert(binary_collection, default_binary_entities)
        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)

    @pytest.mark.timeout(BUILD_TIMEOUT)

@@ -555,8 +566,8 @@ class TestIndexBinary:
        method: create collection, create partition, and add entities in it, create index
        expected: return search success
        '''
-       connect.create_partition(binary_collection, tag)
-       ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
+       connect.create_partition(binary_collection, default_tag)
+       ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)

    @pytest.mark.timeout(BUILD_TIMEOUT)

@@ -567,9 +578,9 @@ class TestIndexBinary:
        expected: return search success
        '''
        nq = get_nq
-       ids = connect.insert(binary_collection, binary_entities)
+       ids = connect.insert(binary_collection, default_binary_entities)
        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
-       query, vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq, metric_type="JACCARD")
+       query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
        search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
        logging.getLogger().info(search_param)
        res = connect.search(binary_collection, query, search_params=search_param)

@@ -583,8 +594,9 @@ class TestIndexBinary:
        expected: return create_index failure
        '''
        # insert 6000 vectors
-       ids = connect.insert(binary_collection, binary_entities)
+       ids = connect.insert(binary_collection, default_binary_entities)
        connect.flush([binary_collection])

        if get_l2_index["index_type"] == "BIN_FLAT":
            res = connect.create_index(binary_collection, binary_field_name, get_l2_index)
        else:

@@ -603,11 +615,11 @@ class TestIndexBinary:
        method: create collection and add entities in it, create index, call describe index
        expected: return code 0, and index instructure
        '''
-       ids = connect.insert(binary_collection, binary_entities)
+       ids = connect.insert(binary_collection, default_binary_entities)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
        stats = connect.get_collection_stats(binary_collection)
-       assert stats["row_count"] == nb
+       assert stats["row_count"] == default_nb
        for partition in stats["partitions"]:
            segments = partition["segments"]
            if segments:

@@ -622,13 +634,13 @@ class TestIndexBinary:
        method: create collection, create partition and add entities in it, create index, call describe index
        expected: return code 0, and index instructure
        '''
-       connect.create_partition(binary_collection, tag)
-       ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
+       connect.create_partition(binary_collection, default_tag)
+       ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
        stats = connect.get_collection_stats(binary_collection)
        logging.getLogger().info(stats)
-       assert stats["row_count"] == nb
+       assert stats["row_count"] == default_nb
        assert len(stats["partitions"]) == 2
        for partition in stats["partitions"]:
            segments = partition["segments"]

@@ -664,14 +676,14 @@ class TestIndexBinary:
        method: create collection, create partition and add entities in it, create index on collection, call drop collection index
        expected: return code 0, and default index param
        '''
-       connect.create_partition(binary_collection, tag)
-       ids = connect.insert(binary_collection, binary_entities, partition_tag=tag)
+       connect.create_partition(binary_collection, default_tag)
+       ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
        connect.flush([binary_collection])
        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
        stats = connect.get_collection_stats(binary_collection)
        connect.drop_index(binary_collection, binary_field_name)
        stats = connect.get_collection_stats(binary_collection)
-       assert stats["row_count"] == nb
+       assert stats["row_count"] == default_nb
        for partition in stats["partitions"]:
            segments = partition["segments"]
            if segments:

@@ -714,7 +726,7 @@ class TestIndexInvalid(object):
    def get_index(self, request):
        yield request.param

-   @pytest.mark.level(1)
+   @pytest.mark.level(2)
    def test_create_index_with_invalid_index_params(self, connect, collection, get_index):
        logging.getLogger().info(get_index)
        with pytest.raises(Exception) as e:

@@ -760,7 +772,7 @@ class TestIndexAsync:
        method: create collection and add entities in it, create index
        expected: return search success
        '''
-       ids = connect.insert(collection, entities)
+       ids = connect.insert(collection, default_entities)
        logging.getLogger().info("start index")
        future = connect.create_index(collection, field_name, get_simple_index, _async=True)
        logging.getLogger().info("before result")

@@ -768,6 +780,20 @@ class TestIndexAsync:
        # TODO:
        logging.getLogger().info(res)

+   @pytest.mark.timeout(BUILD_TIMEOUT)
+   def test_create_index_drop(self, connect, collection, get_simple_index):
+       '''
+       target: test create index interface
+       method: create collection and add entities in it, create index
+       expected: return search success
+       '''
+       ids = connect.insert(collection, default_entities)
+       logging.getLogger().info("start index")
+       future = connect.create_index(collection, field_name, get_simple_index, _async=True)
+       logging.getLogger().info("DROP")
+       connect.drop_collection(collection)
+
    @pytest.mark.level(2)
    def test_create_index_with_invalid_collectionname(self, connect):
        collection_name = " "
        future = connect.create_index(collection_name, field_name, default_index, _async=True)

@@ -781,7 +807,7 @@ class TestIndexAsync:
        method: create collection and add entities in it, create index
        expected: return search success
        '''
-       ids = connect.insert(collection, entities)
+       ids = connect.insert(collection, default_entities)
        logging.getLogger().info("start index")
        future = connect.create_index(collection, field_name, get_simple_index, _async=True,
                                      _callback=self.check_result)
@@ -9,11 +9,8 @@ from multiprocessing import Process
import sklearn.preprocessing
from utils import *

-dim = 128
index_file_size = 10
-collection_id = "test_mix"
add_interval_time = 5
-vectors = gen_vectors(10000, dim)
+vectors = gen_vectors(10000, default_dim)
vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2')
vectors = vectors.tolist()
-top_k = 1

@@ -24,7 +21,6 @@ nlist = 128


class TestMixBase:

    # disable
    def _test_search_during_createIndex(self, args):
        loops = 10000

@@ -35,7 +31,7 @@ class TestMixBase:
        milvus_instance = get_milvus(args["handler"])
        # milvus_instance.connect(uri=uri)
        milvus_instance.create_collection({'collection_name': collection,
-                                          'dimension': dim,
+                                          'dimension': default_dim,
                                           'index_file_size': index_file_size,
                                           'metric_type': "L2"})
        for i in range(10):

@@ -88,7 +84,7 @@ class TestMixBase:
        collection_name = gen_unique_str('test_mix_multi_collections')
        collection_list.append(collection_name)
        param = {'collection_name': collection_name,
-                'dimension': dim,
+                'dimension': default_dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        connect.create_collection(param)

@@ -101,7 +97,7 @@ class TestMixBase:
        collection_name = gen_unique_str('test_mix_multi_collections')
        collection_list.append(collection_name)
        param = {'collection_name': collection_name,
-                'dimension': dim,
+                'dimension': default_dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.IP}
        connect.create_collection(param)
@@ -6,26 +6,11 @@ import logging
from multiprocessing import Pool, Process
import pytest
from utils import *
+from constants import *

-dim = 128
-segment_row_count = 5000
-collection_id = "partition"
-nprobe = 1
-tag = "1970_01_01"
-TIMEOUT = 120
-nb = 6000
-tag = "partition_tag"
-field_name = "float_vector"
-entity = gen_entities(1)
-entities = gen_entities(nb)
-raw_vector, binary_entity = gen_binary_entities(1)
-raw_vectors, binary_entities = gen_binary_entities(nb)
-default_fields = gen_default_fields()


class TestCreateBase:

    """
    ******************************************************************
    The following cases are used to test `create_partition` function

@@ -37,22 +22,24 @@ class TestCreateBase:
        method: call function: create_partition
        expected: status ok
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)

-   # TODO: enable
+   @pytest.mark.level(2)
+   @pytest.mark.timeout(600)
    def test_create_partition_limit(self, connect, collection, args):
        '''
        target: test create partitions, check status returned
        method: call function: create_partition for 4097 times
        expected: exception raised
        '''
-       threads_num = 16
+       threads_num = 8
        threads = []
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")

        def create(connect, threads_num):
-           for i in range(4096 // threads_num):
+           for i in range(max_partition_num // threads_num):
                tag_tmp = gen_unique_str()
                connect.create_partition(collection, tag_tmp)
@@ -73,9 +60,9 @@ class TestCreateBase:
        method: call function: create_partition
        expected: status ok
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
-           connect.create_partition(collection, tag)
+           connect.create_partition(collection, default_tag)

    def test_create_partition_collection_not_existed(self, connect):
        '''

@@ -85,7 +72,7 @@ class TestCreateBase:
        '''
        collection_name = gen_unique_str()
        with pytest.raises(Exception) as e:
-           connect.create_partition(collection_name, tag)
+           connect.create_partition(collection_name, default_tag)

    def test_create_partition_tag_name_None(self, connect, collection):
        '''

@@ -103,11 +90,11 @@ class TestCreateBase:
        method: call function: create_partition, and again
        expected: status ok
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        tag_name = gen_unique_str()
        connect.create_partition(collection, tag_name)
        tag_list = connect.list_partitions(collection)
-       assert tag in tag_list
+       assert default_tag in tag_list
        assert tag_name in tag_list
        assert "_default" in tag_list

@@ -117,9 +104,9 @@ class TestCreateBase:
        method: call function: create_partition
        expected: status ok
        '''
-       connect.create_partition(id_collection, tag)
-       ids = [i for i in range(nb)]
-       insert_ids = connect.insert(id_collection, entities, ids)
+       connect.create_partition(id_collection, default_tag)
+       ids = [i for i in range(default_nb)]
+       insert_ids = connect.insert(id_collection, default_entities, ids)
        assert len(insert_ids) == len(ids)

    def test_create_partition_insert_with_tag(self, connect, id_collection):

@@ -128,9 +115,9 @@ class TestCreateBase:
        method: call function: create_partition
        expected: status ok
        '''
-       connect.create_partition(id_collection, tag)
-       ids = [i for i in range(nb)]
-       insert_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
+       connect.create_partition(id_collection, default_tag)
+       ids = [i for i in range(default_nb)]
+       insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
        assert len(insert_ids) == len(ids)

    def test_create_partition_insert_with_tag_not_existed(self, connect, collection):

@@ -140,10 +127,10 @@ class TestCreateBase:
        expected: status not ok
        '''
        tag_new = "tag_new"
-       connect.create_partition(collection, tag)
-       ids = [i for i in range(nb)]
+       connect.create_partition(collection, default_tag)
+       ids = [i for i in range(default_nb)]
        with pytest.raises(Exception) as e:
-           insert_ids = connect.insert(collection, entities, ids, partition_tag=tag_new)
+           insert_ids = connect.insert(collection, default_entities, ids, partition_tag=tag_new)

    def test_create_partition_insert_same_tags(self, connect, id_collection):
        '''

@@ -151,14 +138,14 @@ class TestCreateBase:
        method: call function: create_partition
        expected: status ok
        '''
-       connect.create_partition(id_collection, tag)
-       ids = [i for i in range(nb)]
-       insert_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
-       ids = [(i+nb) for i in range(nb)]
-       new_insert_ids = connect.insert(id_collection, entities, ids, partition_tag=tag)
+       connect.create_partition(id_collection, default_tag)
+       ids = [i for i in range(default_nb)]
+       insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
+       ids = [(i+default_nb) for i in range(default_nb)]
+       new_insert_ids = connect.insert(id_collection, default_entities, ids, partition_tag=default_tag)
        connect.flush([id_collection])
        res = connect.count_entities(id_collection)
-       assert res == nb * 2
+       assert res == default_nb * 2

    @pytest.mark.level(2)
    def test_create_partition_insert_same_tags_two_collections(self, connect, collection):

@@ -167,17 +154,17 @@ class TestCreateBase:
        method: call function: create_partition
        expected: status ok, collection length is correct
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        collection_new = gen_unique_str()
        connect.create_collection(collection_new, default_fields)
-       connect.create_partition(collection_new, tag)
-       ids = connect.insert(collection, entities, partition_tag=tag)
-       ids = connect.insert(collection_new, entities, partition_tag=tag)
+       connect.create_partition(collection_new, default_tag)
+       ids = connect.insert(collection, default_entities, partition_tag=default_tag)
+       ids = connect.insert(collection_new, default_entities, partition_tag=default_tag)
        connect.flush([collection, collection_new])
        res = connect.count_entities(collection)
-       assert res == nb
+       assert res == default_nb
        res = connect.count_entities(collection_new)
-       assert res == nb
+       assert res == default_nb


class TestShowBase:
@@ -193,9 +180,9 @@ class TestShowBase:
        method: create partition first, then call function: list_partitions
        expected: status ok, partition correct
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        res = connect.list_partitions(collection)
-       assert tag in res
+       assert default_tag in res

    def test_list_partitions_no_partition(self, connect, collection):
        '''

@@ -213,10 +200,10 @@ class TestShowBase:
        expected: status ok, partitions correct
        '''
        tag_new = gen_unique_str()
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        connect.create_partition(collection, tag_new)
        res = connect.list_partitions(collection)
-       assert tag in res
+       assert default_tag in res
        assert tag_new in res

@@ -240,8 +227,8 @@ class TestHasBase:
        method: create partition first, then call function: has_partition
        expected: status ok, result true
        '''
-       connect.create_partition(collection, tag)
-       res = connect.has_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
+       res = connect.has_partition(collection, default_tag)
        logging.getLogger().info(res)
        assert res

@@ -251,9 +238,9 @@ class TestHasBase:
        method: create partition first, then call function: has_partition
        expected: status ok, result true
        '''
-       for tag_name in [tag, "tag_new", "tag_new_new"]:
+       for tag_name in [default_tag, "tag_new", "tag_new_new"]:
            connect.create_partition(collection, tag_name)
-       for tag_name in [tag, "tag_new", "tag_new_new"]:
+       for tag_name in [default_tag, "tag_new", "tag_new_new"]:
            res = connect.has_partition(collection, tag_name)
            assert res

@@ -263,7 +250,7 @@ class TestHasBase:
        method: then call function: has_partition, with tag not existed
        expected: status ok, result empty
        '''
-       res = connect.has_partition(collection, tag)
+       res = connect.has_partition(collection, default_tag)
        logging.getLogger().info(res)
        assert not res

@@ -274,7 +261,7 @@ class TestHasBase:
        expected: status not ok
        '''
        with pytest.raises(Exception) as e:
-           res = connect.has_partition("not_existed_collection", tag)
+           res = connect.has_partition("not_existed_collection", default_tag)

    @pytest.mark.level(2)
    def test_has_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):

@@ -284,13 +271,9 @@ class TestHasBase:
        expected: status ok
        '''
        tag_name = get_tag_name
-       connect.create_partition(collection, tag)
-       if isinstance(tag_name, str):
-           res = connect.has_partition(collection, tag_name)
-           assert not res
-       else:
-           with pytest.raises(Exception) as e:
-               res = connect.has_partition(collection, tag_name)
+       connect.create_partition(collection, default_tag)
+       with pytest.raises(Exception) as e:
+           res = connect.has_partition(collection, tag_name)


class TestDropBase:
@@ -306,11 +289,11 @@ class TestDropBase:
        method: create partitions first, then call function: drop_partition
        expected: status ok, no partitions in db
        '''
-       connect.create_partition(collection, tag)
-       connect.drop_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
+       connect.drop_partition(collection, default_tag)
        res = connect.list_partitions(collection)
        tag_list = []
-       assert tag not in tag_list
+       assert default_tag not in tag_list

    def test_drop_partition_tag_not_existed(self, connect, collection):
        '''

@@ -318,7 +301,7 @@ class TestDropBase:
        method: create partitions first, then call function: drop_partition
        expected: status not ok
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        new_tag = "new_tag"
        with pytest.raises(Exception) as e:
            connect.drop_partition(collection, new_tag)

@@ -329,10 +312,10 @@ class TestDropBase:
        method: create partitions first, then call function: drop_partition
        expected: status not ok
        '''
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        new_collection = gen_unique_str()
        with pytest.raises(Exception) as e:
-           connect.drop_partition(new_collection, tag)
+           connect.drop_partition(new_collection, default_tag)

    @pytest.mark.level(2)
    def test_drop_partition_repeatedly(self, connect, collection):

@@ -341,13 +324,13 @@ class TestDropBase:
        method: create partitions first, then call function: drop_partition
        expected: status not ok, no partitions in db
        '''
-       connect.create_partition(collection, tag)
-       connect.drop_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
+       connect.drop_partition(collection, default_tag)
        time.sleep(2)
        with pytest.raises(Exception) as e:
-           connect.drop_partition(collection, tag)
+           connect.drop_partition(collection, default_tag)
        tag_list = connect.list_partitions(collection)
-       assert tag not in tag_list
+       assert default_tag not in tag_list

    def test_drop_partition_create(self, connect, collection):
        '''

@@ -355,12 +338,12 @@ class TestDropBase:
        method: create partitions first, then call function: drop_partition, create_partition
        expected: status not ok, partition in db
        '''
-       connect.create_partition(collection, tag)
-       connect.drop_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
+       connect.drop_partition(collection, default_tag)
        time.sleep(2)
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        tag_list = connect.list_partitions(collection)
-       assert tag in tag_list
+       assert default_tag in tag_list


class TestNameInvalid(object):

@@ -378,6 +361,7 @@ class TestNameInvalid(object):
    def get_collection_name(self, request):
        yield request.param

+   @pytest.mark.level(2)
    def test_drop_partition_with_invalid_collection_name(self, connect, collection, get_collection_name):
        '''
        target: test drop partition, with invalid collection name, check status returned

@@ -385,10 +369,11 @@ class TestNameInvalid(object):
        expected: status not ok
        '''
        collection_name = get_collection_name
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
-           connect.drop_partition(collection_name, tag)
+           connect.drop_partition(collection_name, default_tag)

+   @pytest.mark.level(2)
    def test_drop_partition_with_invalid_tag_name(self, connect, collection, get_tag_name):
        '''
        target: test drop partition, with invalid tag name, check status returned

@@ -396,10 +381,11 @@ class TestNameInvalid(object):
        expected: status not ok
        '''
        tag_name = get_tag_name
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            connect.drop_partition(collection, tag_name)

+   @pytest.mark.level(2)
    def test_list_partitions_with_invalid_collection_name(self, connect, collection, get_collection_name):
        '''
        target: test show partitions, with invalid collection name, check status returned

@@ -407,6 +393,6 @@ class TestNameInvalid(object):
        expected: status not ok
        '''
        collection_name = get_collection_name
-       connect.create_partition(collection, tag)
+       connect.create_partition(collection, default_tag)
        with pytest.raises(Exception) as e:
            res = connect.list_partitions(collection_name)
@@ -1,49 +0,0 @@
-import time
-import pdb
-import threading
-import logging
-from multiprocessing import Pool, Process
-import pytest
-from utils import *
-
-dim = 128
-collection_id = "test_wal"
-segment_row_count = 5000
-WAL_TIMEOUT = 60
-tag = "1970_01_01"
-insert_interval_time = 1.5
-nb = 6000
-field_name = "float_vector"
-entity = gen_entities(1)
-binary_entity = gen_binary_entities(1)
-entities = gen_entities(nb)
-raw_vectors, binary_entities = gen_binary_entities(nb)
-default_fields = gen_default_fields()
-
-
-class TestWalBase:
-    """
-    ******************************************************************
-    The following cases are used to test WAL functionality
-    ******************************************************************
-    """
-
-    @pytest.mark.timeout(WAL_TIMEOUT)
-    def test_wal_server_crashed_recovery(self, connect, collection):
-        '''
-        target: test wal when server crashed unexpectedly and restarted
-        method: add vectors, server killed before flush, restarted server and flush
-        expected: status ok, add request is recovered and vectors added
-        '''
-        ids = connect.insert(collection, entity)
-        connect.flush([collection])
-        res = connect.count_entities(collection)
-        logging.getLogger().info(res)  # should be 0 because no auto flush
-        logging.getLogger().info("Stop server and restart")
-        # kill server and restart. auto flush should be set to 15 seconds.
-        # time.sleep(15)
-        connect.flush([collection])
-        res = connect.count_entities(collection)
-        assert res == 1
-        res = connect.get_entity_by_id(collection, [ids[0]])
-        logging.getLogger().info(res)
@@ -13,16 +13,25 @@ from milvus import Milvus, DataType

port = 19530
epsilon = 0.000001
namespace = "milvus"

default_flush_interval = 1
big_flush_interval = 1000
-dimension = 128
-nb = 6000
-top_k = 10
-segment_row_count = 5000
default_drop_interval = 3
+default_dim = 128
+default_nb = 1200
+default_top_k = 10
+max_top_k = 16384
+max_partition_num = 256
+default_segment_row_limit = 1000
+default_server_segment_row_limit = 1024 * 512
default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"
+default_partition_name = "_default"
+default_tag = "1970_01_01"

-# TODO:
+# TODO: disable RHNSW_SQ/PQ in 0.11.0
all_index_types = [
    "FLAT",
    "IVF_FLAT",

@@ -32,6 +41,8 @@ all_index_types = [
    "HNSW",
    # "NSG",
    "ANNOY",
+   "RHNSW_PQ",
+   "RHNSW_SQ",
    "BIN_FLAT",
    "BIN_IVF_FLAT"
]
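
(Aside: `all_index_types` and `default_index_params`, which follows in the next hunk, are position-aligned lists -- the two new RHNSW entries pick up the two new HNSW-style parameter dicts, and the BIN_* entries the trailing nlist dicts. A sketch of how a test helper could zip them; `gen_simple_indexes` is an illustrative name, not a function from this patch:)

# Illustrative helper; the position-based pairing is the real convention here.
def gen_simple_indexes(metric_type="L2"):
    return [
        {"index_type": t, "params": p, "metric_type": metric_type}
        for t, p in zip(all_index_types, default_index_params)
    ]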

@@ -45,6 +56,8 @@ default_index_params = [
    {"M": 48, "efConstruction": 500},
    # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
    {"n_trees": 50},
+   {"M": 48, "efConstruction": 500, "PQM": 64},
+   {"M": 48, "efConstruction": 500},
    {"nlist": 128},
    {"nlist": 128}
]

@@ -66,6 +79,10 @@ def ivf():
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]


+def skip_pq():
+    return ["IVF_PQ", "RHNSW_PQ", "RHNSW_SQ"]
+
+
def binary_metrics():
    return ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]

@@ -123,6 +140,10 @@ def get_milvus(host, port, uri=None, handler=None, **kwargs):
    return milvus


+def reset_build_index_threshold(connect):
+    connect.set_config("engine", "build_index_threshold", 1024)
+
+
def disable_flush(connect):
    connect.set_config("storage", "auto_flush_interval", big_flush_interval)

@@ -204,14 +225,14 @@ def gen_single_filter_fields():
    fields = []
    for data_type in DataType:
        if data_type in [DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]:
-           fields.append({"field": data_type.name, "type": data_type})
+           fields.append({"name": data_type.name, "type": data_type})
    return fields


def gen_single_vector_fields():
    fields = []
    for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
-       field = {"field": data_type.name, "type": data_type, "params": {"dim": dimension}}
+       field = {"name": data_type.name, "type": data_type, "params": {"dim": default_dim}}
        fields.append(field)
    return fields

@@ -219,11 +240,11 @@ def gen_single_vector_fields():
def gen_default_fields(auto_id=True):
    default_fields = {
        "fields": [
-           {"field": "int64", "type": DataType.INT64},
-           {"field": "float", "type": DataType.FLOAT},
-           {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": dimension}},
+           {"name": "int64", "type": DataType.INT64},
+           {"name": "float", "type": DataType.FLOAT},
+           {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": default_dim}},
        ],
-       "segment_row_limit": segment_row_count,
+       "segment_row_limit": default_segment_row_limit,
        "auto_id" : auto_id
    }
    return default_fields
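
(Aside: the schema keys move from `"field"` to `"name"` here. With the new module defaults, the dict this returns looks roughly like the literal below; the collection name is illustrative, not from this patch:)

# Approximate result of gen_default_fields() after this change,
# assuming default_dim = 128 and default_segment_row_limit = 1000.
fields = {
    "fields": [
        {"name": "int64", "type": DataType.INT64},
        {"name": "float", "type": DataType.FLOAT},
        {"name": "float_vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}},
    ],
    "segment_row_limit": 1000,
    "auto_id": True,
}
connect.create_collection("example_collection", fields)  # collection name is illustrative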

@@ -232,37 +253,37 @@ def gen_default_fields(auto_id=True):
def gen_binary_default_fields(auto_id=True):
    default_fields = {
        "fields": [
-           {"field": "int64", "type": DataType.INT64},
-           {"field": "float", "type": DataType.FLOAT},
-           {"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": dimension}}
+           {"name": "int64", "type": DataType.INT64},
+           {"name": "float", "type": DataType.FLOAT},
+           {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}}
        ],
-       "segment_row_limit": segment_row_count,
+       "segment_row_limit": default_segment_row_limit,
        "auto_id" : auto_id
    }
    return default_fields


def gen_entities(nb, is_normal=False):
-   vectors = gen_vectors(nb, dimension, is_normal)
+   vectors = gen_vectors(nb, default_dim, is_normal)
    entities = [
-       {"field": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
-       {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
-       {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
+       {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
+       {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
+       {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}
    ]
    return entities


def gen_binary_entities(nb):
-   raw_vectors, vectors = gen_binary_vectors(nb, dimension)
+   raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
    entities = [
-       {"field": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
-       {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
-       {"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": vectors}
+       {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
+       {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
+       {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": vectors}
    ]
    return raw_vectors, entities
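
(Aside: both generators return column-oriented entities -- one dict per field, each carrying nb values -- now keyed by `"name"`. A quick sanity sketch against the fixtures used throughout these tests:)

# Sanity sketch; `connect` and `collection` are the usual test fixtures.
entities = gen_entities(default_nb)
assert [e["name"] for e in entities] == ["int64", "float", "float_vector"]
assert all(len(e["values"]) == default_nb for e in entities)
ids = connect.insert(collection, entities)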

-def gen_entities_by_fields(fields, nb, dimension):
+def gen_entities_by_fields(fields, nb, dim):
    entities = []
    for field in fields:
        if field["type"] in [DataType.INT32, DataType.INT64]:

@@ -270,9 +291,9 @@ def gen_entities_by_fields(fields, nb, dim):
        elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
            field_value = [3.0 for i in range(nb)]
        elif field["type"] == DataType.BINARY_VECTOR:
-           field_value = gen_binary_vectors(nb, dimension)[1]
+           field_value = gen_binary_vectors(nb, dim)[1]
        elif field["type"] == DataType.FLOAT_VECTOR:
-           field_value = gen_vectors(nb, dimension)
+           field_value = gen_vectors(nb, dim)
        field.update({"values": field_value})
        entities.append(field)
    return entities

@@ -316,7 +337,7 @@ def gen_default_vector_expr(default_query):

def gen_default_term_expr(keyword="term", field="int64", values=None):
    if values is None:
-       values = [i for i in range(nb // 2)]
+       values = [i for i in range(default_nb // 2)]
    expr = {keyword: {field: {"values": values}}}
    return expr

@@ -330,7 +351,7 @@ def update_term_expr(src_term, terms):

def gen_default_range_expr(keyword="range", field="int64", ranges=None):
    if ranges is None:
-       ranges = {"GT": 1, "LT": nb // 2}
+       ranges = {"GT": 1, "LT": default_nb // 2}
    expr = {keyword: {field: ranges}}
    return expr
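
(Aside: with default_nb = 1200, the two builders above now expand to:)

gen_default_term_expr()   # {"term": {"int64": {"values": [0, 1, ..., 599]}}}
gen_default_range_expr()  # {"range": {"int64": {"GT": 1, "LT": 600}}}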

@@ -347,18 +368,18 @@ def gen_invalid_range():
        {"range": 1},
        {"range": {}},
        {"range": []},
-       {"range": {"range": {"int64": {"GT": 0, "LT": nb // 2}}}}
+       {"range": {"range": {"int64": {"GT": 0, "LT": default_nb // 2}}}}
    ]
    return range


def gen_valid_ranges():
    ranges = [
-       {"GT": 0, "LT": nb//2},
-       {"GT": nb // 2, "LT": nb*2},
+       {"GT": 0, "LT": default_nb//2},
+       {"GT": default_nb // 2, "LT": default_nb * 2},
        {"GT": 0},
-       {"LT": nb},
-       {"GT": -1, "LT": top_k},
+       {"LT": default_nb},
+       {"GT": -1, "LT": default_top_k},
    ]
    return ranges

@@ -368,7 +389,7 @@ def gen_invalid_term():
        {"term": 1},
        {"term": []},
        {"term": {}},
-       {"term": {"term": {"int64": {"values": [i for i in range(nb // 2)]}}}}
+       {"term": {"term": {"int64": {"values": [i for i in range(default_nb // 2)]}}}}
    ]
    return terms

@@ -378,7 +399,7 @@ def add_field_default(default_fields, type=DataType.INT64, field_name=None):
    if field_name is None:
        field_name = gen_unique_str()
    field = {
-       "field": field_name,
+       "name": field_name,
        "type": type
    }
    tmp_fields["fields"].append(field)

@@ -391,7 +412,7 @@ def add_field(entities, field_name=None):
    if field_name is None:
        field_name = gen_unique_str()
    field = {
-       "field": field_name,
+       "name": field_name,
        "type": DataType.INT64,
        "values": [i for i in range(nb)]
    }

@@ -401,9 +422,9 @@ def add_field(entities, field_name=None):

def add_vector_field(entities, is_normal=False):
    nb = len(entities[0]["values"])
-   vectors = gen_vectors(nb, dimension, is_normal)
+   vectors = gen_vectors(nb, default_dim, is_normal)
    field = {
-       "field": gen_unique_str(),
+       "name": gen_unique_str(),
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }

@@ -434,15 +455,15 @@ def remove_vector_field(entities):
def update_field_name(entities, old_name, new_name):
    tmp_entities = copy.deepcopy(entities)
    for item in tmp_entities:
-       if item["field"] == old_name:
-           item["field"] = new_name
+       if item["name"] == old_name:
+           item["name"] = new_name
    return tmp_entities


def update_field_type(entities, old_name, new_name):
    tmp_entities = copy.deepcopy(entities)
    for item in tmp_entities:
-       if item["field"] == old_name:
+       if item["name"] == old_name:
            item["type"] = new_name
    return tmp_entities

@@ -456,21 +477,20 @@ def update_field_value(entities, old_type, new_value):
    return tmp_entities


-def add_vector_field(nb, dimension=dimension):
+def add_vector_field(nb, dimension=default_dim):
    field_name = gen_unique_str()
    field = {
-       "field": field_name,
+       "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(nb, dimension)
    }
    return field_name


-def gen_segment_row_counts():
+def gen_segment_row_limits():
    sizes = [
-       4096,
-       8192,
-       1000000,
+       1024,
+       4096
    ]
    return sizes

@@ -534,12 +554,8 @@ def gen_invalid_strs():
        # "",
        # None,
        "12 s",
        "BB。A",
        "c|c",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a".join("a" for i in range(256))
    ]

@@ -616,12 +632,7 @@ def gen_invalid_vectors():
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a".join("a" for i in range(256))
    ]

@@ -639,7 +650,7 @@ def gen_invaild_search_params():
            for nprobe in gen_invalid_params():
                ivf_search_params = {"index_type": index_type, "search_params": {"nprobe": nprobe}}
                search_params.append(ivf_search_params)
-       elif index_type == "HNSW":
+       elif index_type in ["HNSW", "RHNSW_PQ", "RHNSW_SQ"]:
            for ef in gen_invalid_params():
                hnsw_search_param = {"index_type": index_type, "search_params": {"ef": ef}}
                search_params.append(hnsw_search_param)

@@ -667,9 +678,13 @@ def gen_invalid_index():
        index_params.append(index_param)
    for M in gen_invalid_params():
        index_param = {"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}
+       index_param = {"index_type": "RHNSW_PQ", "params": {"M": M, "efConstruction": 100}}
+       index_param = {"index_type": "RHNSW_SQ", "params": {"M": M, "efConstruction": 100}}
        index_params.append(index_param)
    for efConstruction in gen_invalid_params():
        index_param = {"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}
+       index_param = {"index_type": "RHNSW_PQ", "params": {"M": 16, "efConstruction": efConstruction}}
+       index_param = {"index_type": "RHNSW_SQ", "params": {"M": 16, "efConstruction": efConstruction}}
        index_params.append(index_param)
    for search_length in gen_invalid_params():
        index_param = {"index_type": "NSG",

@@ -688,6 +703,8 @@ def gen_invalid_index():
        index_params.append(index_param)
    index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
    index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
+   index_params.append({"index_type": "RHNSW_PQ", "params": {"invalid_key": 16, "efConstruction": 100}})
+   index_params.append({"index_type": "RHNSW_SQ", "params": {"invalid_key": 16, "efConstruction": 100}})
    index_params.append({"index_type": "NSG",
                         "params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
                                    "knng": 100}})

@@ -720,7 +737,7 @@ def gen_index():
                       for nlist in nlists \
                       for m in pq_ms]
        index_params.extend(IVFPQ_params)
-   elif index_type == "HNSW":
+   elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
        hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}} \
                       for M in Ms \
                       for efConstruction in efConstructions]

@@ -763,7 +780,7 @@ def get_search_param(index_type, metric_type="L2"):
    search_params = {"metric_type": metric_type}
    if index_type in ivf() or index_type in binary_support():
        search_params.update({"nprobe": 64})
-   elif index_type == "HNSW":
+   elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
        search_params.update({"ef": 64})
    elif index_type == "NSG":
        search_params.update({"search_length": 100})

@@ -788,7 +805,6 @@ def restart_server(helm_release_name):
    from kubernetes import client, config
    client.rest.logger.setLevel(logging.WARNING)

-   namespace = "milvus"
    # service_name = "%s.%s.svc.cluster.local" % (helm_release_name, namespace)
    config.load_kube_config()
    v1 = client.CoreV1Api()

@@ -802,7 +818,7 @@ def restart_server(helm_release_name):
            break
    # v1.patch_namespaced_config_map(config_map_name, namespace, body, pretty='true')
    # status_res = v1.read_namespaced_service_status(helm_release_name, namespace, pretty='true')
    # print(status_res)
    logging.getLogger().debug("Pod name: %s" % pod_name)
    if pod_name is not None:
        try:
            v1.delete_namespaced_pod(pod_name, namespace)

@@ -811,24 +827,61 @@ def restart_server(helm_release_name):
            logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
            res = False
            return res
-   time.sleep(5)
+   logging.error("Sleep 10s after pod deleted")
+   time.sleep(10)
    # check if restart successfully
    pods = v1.list_namespaced_pod(namespace)
    for i in pods.items:
        pod_name_tmp = i.metadata.name
        if pod_name_tmp.find(helm_release_name) != -1:
-           logging.debug(pod_name_tmp)
+           logging.error(pod_name_tmp)
            if pod_name_tmp == pod_name:
                continue
            elif pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1:
                continue
            else:
                status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
+               logging.error(status_res.status.phase)
                start_time = time.time()
-               while time.time() - start_time > timeout:
+               ready_break = False
+               while time.time() - start_time <= timeout:
+                   logging.error(time.time())
                    status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
                    if status_res.status.phase == "Running":
+                       logging.error("Already running")
+                       ready_break = True
+                       time.sleep(10)
                        break
-                   time.sleep(1)
+                   else:
+                       time.sleep(1)
                    if time.time() - start_time > timeout:
                        logging.error("Restart pod: %s timeout" % pod_name_tmp)
                        res = False
                        return res
+               if ready_break:
+                   break
    else:
-       logging.error("Pod: %s not found" % helm_release_name)
-       res = False
+       raise Exception("Pod: %s not found" % pod_name)
+   follow = True
+   pretty = True
+   previous = True  # bool | Return previous terminated container logs. Defaults to false. (optional)
+   since_seconds = 56  # int | A relative time in seconds before the current time from which to show logs. If this value precedes the time a pod was started, only logs since the pod start will be returned. If this value is in the future, no logs will be returned. Only one of sinceSeconds or sinceTime may be specified. (optional)
+   timestamps = True  # bool | If true, add an RFC3339 or RFC3339Nano timestamp at the beginning of every line of log output. Defaults to false. (optional)
+   container = "milvus"
+   # start_time = time.time()
+   # while time.time() - start_time <= timeout:
+   #     try:
+   #         api_response = v1.read_namespaced_pod_log(pod_name_tmp, namespace, container=container, follow=follow,
+   #                                                   pretty=pretty, previous=previous, since_seconds=since_seconds,
+   #                                                   timestamps=timestamps)
+   #         logging.error(api_response)
+   #         return res
+   #     except Exception as e:
+   #         logging.error("Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
+   #         # waiting for server start
+   #         time.sleep(5)
+   #         # res = False
+   #         # return res
+   #     if time.time() - start_time > timeout:
+   #         logging.error("Restart pod: %s timeout" % pod_name_tmp)
+   #         res = False
    return res