Merge remote-tracking branch 'source/0.5.0' into branch-0.5.0

Former-commit-id: 1383a784ce89252082a752ac91cfb6242428cbda
pull/191/head
starlord 2019-10-18 19:50:17 +08:00
commit fb8c3b0753
78 changed files with 10801 additions and 0 deletions

View File

@ -26,6 +26,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-653 - When config check fail, Milvus close without message
- MS-654 - Describe index timeout when building index
- MS-658 - Fix SQ8 Hybrid can't search
- \#9 Change default gpu_cache_capacity to 4
- MS-665 - IVF_SQ8H search crash when no GPU resource in search_resources
- \#20 - C++ sdk example get grpc error
- \#23 - Add unittest to improve code coverage
@ -75,6 +76,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-624 - Re-organize project directory for open-source
- MS-635 - Add compile option to support customized faiss
- MS-660 - add ubuntu_build_deps.sh
- \#18 - Add all test cases
# Milvus 0.4.0 (2019-09-12)

4
tests/milvus-java-test/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
target/
.idea/
test-output/
lib/*

View File

View File

@ -0,0 +1,10 @@
// Transfers files from the Jenkins workspace to a remote host using the
// Publish Over FTP plugin (`ftpPublisher` step).
//   sourceFiles     - ant-style pattern of workspace files to send
//   remoteDirectory - destination directory on the remote server
//   remoteIP        - name of the pre-configured publisher entry (configName);
//                     NOTE(review): despite the name this is a config label,
//                     not necessarily an IP — confirm against Jenkins config
//   protocol        - only "ftp" is handled; any other value makes this a no-op
//   makeEmptyDirs   - whether empty directories are created on the remote side
def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) {
if (protocol == "ftp") {
// failOnError: a failed transfer fails the build instead of being ignored.
ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, failOnError: true, publishers: [
[configName: "${remoteIP}", transfers: [
[asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true
]
]
}
}
// Script is consumed via Jenkins `load`; returning `this` exposes FileTransfer
// to the loading pipeline.
return this

View File

@ -0,0 +1,13 @@
// Best-effort teardown of the Helm release created by this build
// (${env.JOB_NAME}-${env.BUILD_NUMBER}). The original script duplicated the
// status/delete sequence in both the try and the catch branch; it is now a
// single closure invoked from both places.
def purgeReleaseIfPresent = {
    // `helm status` exits non-zero when the release does not exist, so only
    // purge when the status call succeeds (returnStatus avoids failing the build).
    def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true
    if (!result) {
        sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
    }
}
try {
    purgeReleaseIfPresent()
} catch (exc) {
    // Retry the cleanup once, then rethrow so the stage still fails.
    purgeReleaseIfPresent()
    throw exc
}

View File

@ -0,0 +1,16 @@
// Installs a Milvus GPU server into the `milvus-sdk-test` namespace via Helm.
// Relies on Jenkins-provided bindings: HELM_BRANCH, IMAGE_TAG, params.GIT_USER,
// env.JOB_NAME, env.BUILD_NUMBER. On failure the (possibly half-installed)
// release is purged and the error rethrown so the stage fails.
try {
// Client-only init against a mainland-China mirror of the stable chart repo.
sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts'
sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus'
sh 'helm repo update'
dir ("milvus-helm") {
// Fetch only the HELM_BRANCH ref of the private milvus-helm repository.
checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]])
dir ("milvus/milvus-gpu") {
// Install chart version 0.3.1; --wait blocks up to 300s for pods to be ready.
sh "helm install --wait --timeout 300 --set engine.image.tag=${IMAGE_TAG} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-sdk-test --version 0.3.1 ."
}
}
} catch (exc) {
echo 'Helm running failed!'
sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
throw exc
}

View File

@ -0,0 +1,13 @@
// Builds the Java test project and runs the TestNG suite (com.MainClass)
// against the freshly deployed Milvus service, addressed by its
// cluster-internal DNS name. Hard stop after 30 minutes.
timeout(time: 30, unit: 'MINUTES') {
try {
dir ("milvus-java-test") {
sh "mvn clean install"
// The jar name must match the artifactId/version in pom.xml; dependencies
// are copied into lib/ by the maven-dependency-plugin at package time.
sh "java -cp \"target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*\" com.MainClass -h ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-sdk-test.svc.cluster.local"
}
} catch (exc) {
echo 'Milvus-SDK-Java Integration Test Failed !'
throw exc
}
}

View File

@ -0,0 +1,15 @@
// Sends a notification email via the Email Extension plugin when the build
// result is worse than SUCCESS. The '$DEFAULT_*' tokens are expanded by the
// emailext plugin from its global configuration, so they must stay
// single-quoted (no Groovy interpolation).
def notify() {
if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
// Send an email only if the build status has changed from green/unstable to red
emailext subject: '$DEFAULT_SUBJECT',
body: '$DEFAULT_CONTENT',
recipientProviders: [
// Committers of the changes in this build plus whoever triggered it.
[$class: 'DevelopersRecipientProvider'],
[$class: 'RequesterRecipientProvider']
],
replyTo: '$DEFAULT_REPLYTO',
to: '$DEFAULT_RECIPIENTS'
}
}
// Script is consumed via Jenkins `load`; returning `this` exposes notify().
return this

View File

@ -0,0 +1,13 @@
// Uploads the test output directory to NAS storage over FTP and, on success,
// echoes an FTP URL where the results can be browsed.
// Expects PROJECT_NAME, JOB_NAME and BUILD_ID in the environment.
timeout(time: 5, unit: 'MINUTES') {
dir ("${PROJECT_NAME}_test") {
if (fileExists('test_out')) {
// Shared helper: FileTransfer(sourceFiles, remoteDirectory, remoteIP);
// 'nas storage' is the publisher configName, not an address.
def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy"
fileTransfer.FileTransfer("test_out/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage')
if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\""
}
} else {
error("Milvus Dev Test Out directory don't exists!")
}
}
}

View File

@ -0,0 +1,110 @@
// Declarative pipeline for the Milvus Java SDK release tests: provisions a
// test pod, deploys a Milvus server with Helm, runs the Java integration
// tests, then always tears the Helm release down in post{}.
// Fix: the gitlabCommitStatus name for the deploy stage was misspelled
// 'Deloy Server'; corrected to 'Deploy Server' to match the stage name.
pipeline {
    agent none
    options {
        timestamps()
    }
    environment {
        SRC_BRANCH = "master"
        // Release images are tagged "<tag>-release"; helm/test branches track the tag.
        IMAGE_TAG = "${params.IMAGE_TAG}-release"
        HELM_BRANCH = "${params.IMAGE_TAG}"
        TEST_URL = "git@192.168.1.105:Test/milvus-java-test.git"
        TEST_BRANCH = "${params.IMAGE_TAG}"
    }
    stages {
        stage("Setup env") {
            agent {
                kubernetes {
                    label 'dev-test'
                    defaultContainer 'jnlp'
                    // NOTE(review): "componet" below looks like a typo for
                    // "component"; kept as-is in case a selector elsewhere
                    // matches on this exact label key.
                    yaml """
apiVersion: v1
kind: Pod
metadata:
  labels:
    app: milvus
    componet: test
spec:
  containers:
  - name: milvus-testframework-java
    image: registry.zilliz.com/milvus/milvus-java-test:v0.1
    command:
    - cat
    tty: true
    volumeMounts:
    - name: kubeconf
      mountPath: /root/.kube/
      readOnly: true
  volumes:
  - name: kubeconf
    secret:
      secretName: test-cluster-config
"""
                }
            }
            stages {
                stage("Deploy Server") {
                    steps {
                        gitlabCommitStatus(name: 'Deploy Server') {
                            container('milvus-testframework-java') {
                                script {
                                    load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/deploy_server.groovy"
                                }
                            }
                        }
                    }
                }
                stage("Integration Test") {
                    steps {
                        gitlabCommitStatus(name: 'Integration Test') {
                            container('milvus-testframework-java') {
                                script {
                                    print "In integration test stage"
                                    load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/integration_test.groovy"
                                }
                            }
                        }
                    }
                }
                stage ("Cleanup Env") {
                    steps {
                        gitlabCommitStatus(name: 'Cleanup Env') {
                            container('milvus-testframework-java') {
                                script {
                                    load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy"
                                }
                            }
                        }
                    }
                }
            }
            post {
                // Cleanup runs unconditionally, even when a stage aborted
                // before reaching "Cleanup Env".
                always {
                    container('milvus-testframework-java') {
                        script {
                            load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy"
                        }
                    }
                }
                success {
                    script {
                        echo "Milvus java-sdk test success !"
                    }
                }
                aborted {
                    script {
                        echo "Milvus java-sdk test aborted !"
                    }
                }
                failure {
                    script {
                        echo "Milvus java-sdk test failed !"
                    }
                }
            }
        }
    }
}

View File

@ -0,0 +1,13 @@
# Kubernetes pod template for the Milvus Java SDK test framework.
apiVersion: v1
kind: Pod
metadata:
  labels:
    app: milvus
    # NOTE(review): "componet" looks like a typo for "component" — kept as-is
    # in case a selector elsewhere matches this exact key/value.
    componet: testframework-java
spec:
  containers:
  - name: milvus-testframework-java
    # Maven + JDK 8 image; the build runs `mvn` inside this container.
    image: maven:3.6.2-jdk-8
    # `cat` with a TTY keeps the container alive so CI can exec steps into it.
    command:
    - cat
    tty: true

View File

@ -0,0 +1,2 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4" />

View File

@ -0,0 +1,137 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Build descriptor for the Milvus Java SDK integration tests.
     NOTE(review): artifactId "MilvusSDkJavaTest" has an odd capitalisation,
     but the CI scripts reference the resulting jar by this exact name, so it
     must not be changed here alone. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>milvus</groupId>
    <artifactId>MilvusSDkJavaTest</artifactId>
    <version>1.0-SNAPSHOT</version>
    <build>
        <plugins>
            <!-- Copies runtime dependencies into lib/ so the test jar can be
                 launched with a plain "java -cp target/...jar:lib/*". -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy-dependencies</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>lib</outputDirectory>
                            <overWriteReleases>false</overWriteReleases>
                            <overWriteSnapshots>true</overWriteSnapshots>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <grpc.version>1.23.0</grpc.version><!-- CURRENT_GRPC_VERSION -->
        <protobuf.version>3.9.0</protobuf.version>
        <protoc.version>3.9.0</protoc.version>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>
    <!-- <dependencyManagement>-->
    <!-- <dependencies>-->
    <!-- <dependency>-->
    <!-- <groupId>io.grpc</groupId>-->
    <!-- <artifactId>grpc-bom</artifactId>-->
    <!-- <version>${grpc.version}</version>-->
    <!-- <type>pom</type>-->
    <!-- <scope>import</scope>-->
    <!-- </dependency>-->
    <!-- </dependencies>-->
    <!-- </dependencyManagement>-->
    <repositories>
        <!-- Snapshot repo for milvus-sdk-java SNAPSHOT builds.
             Fix: fetch over HTTPS instead of plain HTTP (insecure dependency
             download; Maven Central and Sonatype require HTTPS). -->
        <repository>
            <id>oss.sonatype.org-snapshot</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
            </snapshots>
        </repository>
    </repositories>
    <dependencies>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.4</version>
        </dependency>
        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.3</version>
        </dependency>
        <dependency>
            <groupId>org.testng</groupId>
            <artifactId>testng</artifactId>
            <version>6.10</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.9</version>
        </dependency>
        <!-- <dependency>-->
        <!-- <groupId>io.milvus</groupId>-->
        <!-- <artifactId>milvus-sdk-java</artifactId>-->
        <!-- <version>0.1.0</version>-->
        <!-- </dependency>-->
        <dependency>
            <groupId>io.milvus</groupId>
            <artifactId>milvus-sdk-java</artifactId>
            <version>0.1.1-SNAPSHOT</version>
        </dependency>
        <!-- <dependency>-->
        <!-- <groupId>io.grpc</groupId>-->
        <!-- <artifactId>grpc-netty-shaded</artifactId>-->
        <!-- <scope>runtime</scope>-->
        <!-- </dependency>-->
        <!-- <dependency>-->
        <!-- <groupId>io.grpc</groupId>-->
        <!-- <artifactId>grpc-protobuf</artifactId>-->
        <!-- </dependency>-->
        <!-- <dependency>-->
        <!-- <groupId>io.grpc</groupId>-->
        <!-- <artifactId>grpc-stub</artifactId>-->
        <!-- </dependency>-->
        <!-- <dependency>-->
        <!-- <groupId>javax.annotation</groupId>-->
        <!-- <artifactId>javax.annotation-api</artifactId>-->
        <!-- <version>1.2</version>-->
        <!-- <scope>provided</scope> &lt;!&ndash; not needed at runtime &ndash;&gt;-->
        <!-- </dependency>-->
        <!-- <dependency>-->
        <!-- <groupId>io.grpc</groupId>-->
        <!-- <artifactId>grpc-testing</artifactId>-->
        <!-- <scope>test</scope>-->
        <!-- </dependency>-->
        <!-- <dependency>-->
        <!-- <groupId>com.google.protobuf</groupId>-->
        <!-- <artifactId>protobuf-java-util</artifactId>-->
        <!-- <version>${protobuf.version}</version>-->
        <!-- </dependency>-->
    </dependencies>
</project>

View File

@ -0,0 +1,147 @@
package com;
import io.milvus.client.*;
import org.apache.commons.cli.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.SkipException;
import org.testng.TestNG;
import org.testng.annotations.DataProvider;
import org.testng.xml.XmlClass;
import org.testng.xml.XmlSuite;
import org.testng.xml.XmlTest;
import java.util.ArrayList;
import java.util.List;

/**
 * Entry point and shared TestNG data providers for the Milvus Java SDK tests.
 *
 * Parses -h/--host and -p/--port from the command line, then programmatically
 * assembles and runs the TestNG suite. The data providers hand each test a
 * connected (or deliberately disconnected) client plus a random table name.
 */
public class MainClass {
    private static String host = "127.0.0.1";
    private static String port = "19530";
    // Table defaults used by the "Table" provider.
    public Integer index_file_size = 50;
    public Integer dimension = 128;

    public static void setHost(String host) {
        MainClass.host = host;
    }

    public static void setPort(String port) {
        MainClass.port = port;
    }

    /** Supplies the raw (host, port) pair for connection tests. */
    @DataProvider(name="DefaultConnectArgs")
    public static Object[][] defaultConnectArgs(){
        return new Object[][]{{host, port}};
    }

    /** Supplies a connected client plus a fresh random table name. */
    @DataProvider(name="ConnectInstance")
    public Object[][] connectInstance(){
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        client.connect(connectParam);
        String tableName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, tableName}};
    }

    /** Supplies a client that has been connected and then disconnected. */
    @DataProvider(name="DisConnectInstance")
    public Object[][] disConnectInstance(){
        // Generate connection instance
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        client.connect(connectParam);
        try {
            client.disconnect();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        String tableName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, tableName}};
    }

    /**
     * Supplies one {client, tableName} pair per metric type, creating the
     * table up-front. Skips the test when table creation fails.
     */
    @DataProvider(name="Table")
    public Object[][] provideTable(){
        MetricType[] metricTypes = { MetricType.L2, MetricType.IP };
        // Fix: size the result from metricTypes.length instead of the
        // hard-coded 2, so adding a metric type cannot silently truncate.
        Object[][] tables = new Object[metricTypes.length][2];
        for (int i = 0; i < metricTypes.length; ++i) {
            String tableName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
            // Generate connection instance
            MilvusClient client = new MilvusGrpcClient();
            ConnectParam connectParam = new ConnectParam.Builder()
                    .withHost(host)
                    .withPort(port)
                    .build();
            client.connect(connectParam);
            TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                    .withIndexFileSize(index_file_size)
                    .withMetricType(metricTypes[i])
                    .build();
            TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
            Response res = client.createTable(tableSchemaParam);
            if (!res.ok()) {
                System.out.println(res.getMessage());
                throw new SkipException("Table created failed");
            }
            tables[i] = new Object[]{client, tableName};
        }
        return tables;
    }

    public static void main(String[] args) {
        CommandLineParser parser = new DefaultParser();
        Options options = new Options();
        options.addOption("h", "host", true, "milvus-server hostname/ip");
        options.addOption("p", "port", true, "milvus-server port");
        try {
            CommandLine cmd = parser.parse(options, args);
            String host = cmd.getOptionValue("host");
            if (host != null) {
                setHost(host);
            }
            String port = cmd.getOptionValue("port");
            if (port != null) {
                setPort(port);
            }
            System.out.println("Host: "+host+", Port: "+port);
        }
        catch(ParseException exp) {
            System.err.println("Parsing failed. Reason: " + exp.getMessage() );
        }
        // Build the suite programmatically so the jar can be run without a
        // testng.xml file on the classpath.
        XmlSuite suite = new XmlSuite();
        suite.setName("TmpSuite");
        XmlTest test = new XmlTest(suite);
        test.setName("TmpTest");
        List<XmlClass> classes = new ArrayList<XmlClass>();
        classes.add(new XmlClass("com.TestPing"));
        classes.add(new XmlClass("com.TestAddVectors"));
        classes.add(new XmlClass("com.TestConnect"));
        classes.add(new XmlClass("com.TestDeleteVectors"));
        classes.add(new XmlClass("com.TestIndex"));
        classes.add(new XmlClass("com.TestSearchVectors"));
        classes.add(new XmlClass("com.TestTable"));
        classes.add(new XmlClass("com.TestTableCount"));
        test.setXmlClasses(classes) ;
        List<XmlSuite> suites = new ArrayList<XmlSuite>();
        suites.add(suite);
        TestNG tng = new TestNG();
        tng.setXmlSuites(suites);
        tng.run();
    }
}

View File

@ -0,0 +1,154 @@
package com;
import io.milvus.client.InsertParam;
import io.milvus.client.InsertResponse;
import io.milvus.client.MilvusClient;
import io.milvus.client.TableParam;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Insert-path tests for the Milvus Java SDK.
 *
 * Fix: the original called Thread.sleep via Thread.currentThread().sleep(...)
 * — sleep() is static and always pauses the *current* thread, so calling it
 * through an instance is misleading; replaced with the plain static call
 * (identical behavior).
 */
public class TestAddVectors {
    int dimension = 128;

    /** Generates {@code nb} random vectors of {@code dimension} floats. */
    public List<List<Float>> gen_vectors(Integer nb) {
        List<List<Float>> xb = new LinkedList<>();
        Random random = new Random();
        for (int i = 0; i < nb; ++i) {
            LinkedList<Float> vector = new LinkedList<>();
            for (int j = 0; j < dimension; j++) {
                vector.add(random.nextFloat());
            }
            xb.add(vector);
        }
        return xb;
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        String tableNameNew = tableName + "_";
        InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_add_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        // Give the server a moment to flush before counting rows.
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_timeout(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 200000;
        List<List<Float>> vectors = gen_vectors(nb);
        System.out.println(new Date());
        // A 1-second timeout on a 200k-vector insert is expected to fail.
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(1).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_big_data(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 500000;
        List<List<Float>> vectors = gen_vectors(nb);
        System.out.println(new Date());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_ids(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        // Add vectors with ids
        List<Long> vectorIds;
        vectorIds = Stream.iterate(0L, n -> n)
                .limit(nb)
                .collect(Collectors.toList());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
    }

    // TODO: MS-628
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_invalid_ids(MilvusClient client, String tableName) {
        int nb = 10;
        List<List<Float>> vectors = gen_vectors(nb);
        // One more id than vectors: the insert must be rejected.
        List<Long> vectorIds;
        vectorIds = Stream.iterate(0L, n -> n)
                .limit(nb+1)
                .collect(Collectors.toList());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_invalid_dimension(MilvusClient client, String tableName) {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        // First vector gets an extra component -> dimension mismatch.
        vectors.get(0).add((float) 0);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_invalid_vectors(MilvusClient client, String tableName) {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        // First vector replaced with an empty one -> must be rejected.
        vectors.set(0, new ArrayList<>());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_repeatably(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100000;
        int loops = 10;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = null;
        for (int i = 0; i < loops; ++i ) {
            long startTime = System.currentTimeMillis();
            res = client.insert(insertParam);
            long endTime = System.currentTimeMillis();
            System.out.println("Total execution time: " + (endTime-startTime) + "ms");
        }
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb * loops);
    }
}

View File

@ -0,0 +1,80 @@
package com;
import io.milvus.client.ConnectParam;
import io.milvus.client.MilvusClient;
import io.milvus.client.MilvusGrpcClient;
import io.milvus.client.Response;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Connect/disconnect behavior tests for the Milvus Java SDK client.
 */
public class TestConnect {

    /** Builds a ConnectParam for the given host/port pair. */
    private static ConnectParam endpoint(String host, String port) {
        return new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
    }

    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_connect(String host, String port){
        System.out.println("Host: "+host+", Port: "+port);
        MilvusClient grpcClient = new MilvusGrpcClient();
        Response status = grpcClient.connect(endpoint(host, port));
        assert(status.ok());
        assert(grpcClient.connected());
    }

    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_connect_repeat(String host, String port){
        MilvusGrpcClient grpcClient = new MilvusGrpcClient();
        ConnectParam param = endpoint(host, port);
        grpcClient.connect(param);
        // A second connect on an already-connected client must be refused,
        // while the existing connection stays up.
        Response status = grpcClient.connect(param);
        assert(!status.ok());
        assert(grpcClient.connected());
    }

    @Test(dataProvider="InvalidConnectArgs")
    public void test_connect_invalid_connect_args(String ip, String port) throws InterruptedException {
        MilvusClient grpcClient = new MilvusGrpcClient();
        grpcClient.connect(endpoint(ip, port));
        // None of the provided endpoints are reachable/valid.
        assert(!grpcClient.connected());
    }

    // TODO: MS-615
    /** Unreachable hosts and out-of-range/unserved ports. */
    @DataProvider(name="InvalidConnectArgs")
    public Object[][] generate_invalid_connect_args() {
        String port = "19530";
        String ip = "";
        return new Object[][]{
                {"1.1.1.1", port},
                {"255.255.0.0", port},
                {"1.2.2", port},
                {"中文", port},
                {"www.baidu.com", "100000"},
                {"127.0.0.1", "100000"},
                {"www.baidu.com", "80"},
        };
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_disconnect(MilvusClient client, String tableName){
        // Provider hands us a client that has already been disconnected.
        assert(!client.connected());
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_disconnect_repeatably(MilvusClient client, String tableName){
        Response status = null;
        try {
            status = client.disconnect();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // Disconnecting twice is expected to report success.
        assert(status.ok());
        assert(!client.connected());
    }
}

View File

@ -0,0 +1,122 @@
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.*;

/**
 * Date-range delete tests for the Milvus Java SDK.
 */
public class TestDeleteVectors {
    int index_file_size = 50;
    int dimension = 128;

    /** Generates {@code nb} random vectors of {@code dimension} floats. */
    public List<List<Float>> gen_vectors(Integer nb) {
        Random rng = new Random();
        List<List<Float>> result = new LinkedList<>();
        for (int row = 0; row < nb; ++row) {
            LinkedList<Float> current = new LinkedList<>();
            for (int col = 0; col < dimension; col++) {
                current.add(rng.nextFloat());
            }
            result.add(current);
        }
        return result;
    }

    /** Returns today shifted by {@code delta} days. */
    public static Date getDeltaDate(int delta) {
        Date today = new Date();
        Calendar c = Calendar.getInstance();
        c.setTime(today);
        c.add(Calendar.DAY_OF_MONTH, delta);
        return c.getTime();
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        // Add vectors
        InsertResponse insertRes = client.insert(new InsertParam.Builder(tableName, gen_vectors(nb)).build());
        assert(insertRes.getResponse().ok());
        Thread.sleep(1000);
        // Yesterday..tomorrow covers everything just inserted.
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(-1), getDeltaDate(1)), tableName).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(deleteRes.ok());
        Thread.sleep(1000);
        // Assert table row count
        TableParam countParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(countParam).getTableRowCount(), 0);
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
        String missingTable = tableName + "_";
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(-1), getDeltaDate(1)), missingTable).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(!deleteRes.ok());
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_delete_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(-1), getDeltaDate(1)), tableName).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(!deleteRes.ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_table_empty(MilvusClient client, String tableName) throws InterruptedException {
        // Deleting from an empty table succeeds (nothing to remove).
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(-1), getDeltaDate(1)), tableName).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(deleteRes.ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_invalid_date_range(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        // Add vectors
        InsertResponse insertRes = client.insert(new InsertParam.Builder(tableName, gen_vectors(nb)).build());
        assert(insertRes.getResponse().ok());
        Thread.sleep(1000);
        // Start after end -> rejected.
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(1), getDeltaDate(0)), tableName).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(!deleteRes.ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_invalid_date_range_1(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        InsertResponse insertRes = client.insert(new InsertParam.Builder(tableName, gen_vectors(nb)).build());
        assert(insertRes.getResponse().ok());
        // Start well after end -> rejected.
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(2), getDeltaDate(-1)), tableName).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(!deleteRes.ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_no_result(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        InsertResponse insertRes = client.insert(new InsertParam.Builder(tableName, gen_vectors(nb)).build());
        assert(insertRes.getResponse().ok());
        Thread.sleep(1000);
        // Range entirely in the past: succeeds but deletes nothing.
        DeleteByRangeParam deleteParam = new DeleteByRangeParam.Builder(
                new DateRange(getDeltaDate(-3), getDeltaDate(-2)), tableName).build();
        Response deleteRes = client.deleteByRange(deleteParam);
        assert(deleteRes.ok());
        TableParam countParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(countParam).getTableRowCount(), nb);
    }
}

View File

@ -0,0 +1,340 @@
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
public class TestIndex {
int index_file_size = 10;
int dimension = 128;
int n_list = 1024;
int default_n_list = 16384;
int nb = 100000;
IndexType indexType = IndexType.IVF_SQ8;
IndexType defaultIndexType = IndexType.FLAT;
public List<List<Float>> gen_vectors(Integer nb) {
List<List<Float>> xb = new LinkedList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
LinkedList<Float> vector = new LinkedList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
xb.add(vector);
}
return xb;
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index(MilvusClient client, String tableName) throws InterruptedException {
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_repeatably(MilvusClient client, String tableName) throws InterruptedException {
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getNList(), n_list);
Assert.assertEquals(index1.getIndexType(), indexType);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_FLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.FLAT;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getIndexType(), indexType);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException {
int nb = 500000;
IndexType indexType = IndexType.IVF_SQ8;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
System.out.println(new Date());
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).withTimeout(1).build();
Response res_create = client.createIndex(createIndexParam);
assert(!res_create.ok());
}
// Builds an IVFLAT index and verifies describeIndex reports the same type.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVFLAT;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getIndexType(), indexType);
}
// Builds an IVF_SQ8 index and verifies describeIndex reports the same type.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVF_SQ8;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getIndexType(), indexType);
}
// Builds an IVF_SQ8_H (GPU hybrid) index and verifies describeIndex reports it.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_IVFSQ8H(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVF_SQ8_H;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getIndexType(), indexType);
}
// Creating an index on an empty table (no vectors inserted) is expected to succeed.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_with_no_vector(MilvusClient client, String tableName) {
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
}
// createIndex against a non-existent table name must fail.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
String tableNameNew = tableName + "_";
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableNameNew).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(!res_create.ok());
}
// createIndex on a disconnected client must fail.
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_create_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(!res_create.ok());
}
// An nlist of 0 is invalid, so createIndex must be rejected by the server.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_create_index_invalid_n_list(MilvusClient client, String tableName) throws InterruptedException {
int n_list = 0;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(!res_create.ok());
}
// describeIndex after a successful createIndex must echo back both the
// configured nlist and index type.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_describe_index(MilvusClient client, String tableName) throws InterruptedException {
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getNList(), n_list);
Assert.assertEquals(index1.getIndexType(), indexType);
}
// Creating a second index on the same table must replace the first; describeIndex
// afterwards has to report the new index type and nlist.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_alter_index(MilvusClient client, String tableName) throws InterruptedException {
    List<List<Float>> vectors = gen_vectors(nb);
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    Response res_create = client.createIndex(createIndexParam);
    assert(res_create.ok());
    TableParam tableParam = new TableParam.Builder(tableName).build();
    // Create another index with a different type and nlist; this should overwrite the first one.
    IndexType indexTypeNew = IndexType.IVFLAT;
    int n_list_new = n_list + 1;
    Index index_new = new Index.Builder().withIndexType(indexTypeNew)
            .withNList(n_list_new)
            .build();
    CreateIndexParam createIndexParamNew = new CreateIndexParam.Builder(tableName).withIndex(index_new).build();
    Response res_create_new = client.createIndex(createIndexParamNew);
    assert(res_create_new.ok());
    DescribeIndexResponse res = client.describeIndex(tableParam);
    // BUG FIX: previously re-asserted res_create.ok() here (already checked above)
    // instead of validating the describeIndex response itself.
    assert(res.getResponse().ok());
    Index index1 = res.getIndex().get();
    Assert.assertEquals(index1.getNList(), n_list_new);
    Assert.assertEquals(index1.getIndexType(), indexTypeNew);
}
// describeIndex against a non-existent table name must fail.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_describe_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
String tableNameNew = tableName + "_";
TableParam tableParam = new TableParam.Builder(tableNameNew).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(!res.getResponse().ok());
}
// describeIndex on a disconnected client must fail.
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_describe_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
TableParam tableParam = new TableParam.Builder(tableName).build();
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(!res.getResponse().ok());
}
// After dropIndex, describeIndex is expected to report the server defaults
// (defaultIndexType / default_n_list) rather than the dropped index.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_drop_index(MilvusClient client, String tableName) throws InterruptedException {
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(defaultIndexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
Response res_create = client.createIndex(createIndexParam);
assert(res_create.ok());
TableParam tableParam = new TableParam.Builder(tableName).build();
Response res_drop = client.dropIndex(tableParam);
assert(res_drop.ok());
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getNList(), default_n_list);
Assert.assertEquals(index1.getIndexType(), defaultIndexType);
}
// Dropping the index twice in a row must succeed both times (the second drop is
// a no-op), and describeIndex afterwards reports the server defaults.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_drop_index_repeatably(MilvusClient client, String tableName) throws InterruptedException {
    List<List<Float>> vectors = gen_vectors(nb);
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(defaultIndexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    Response res_create = client.createIndex(createIndexParam);
    assert(res_create.ok());
    TableParam tableParam = new TableParam.Builder(tableName).build();
    Response res_drop = client.dropIndex(tableParam);
    // BUG FIX: the result of the first drop was silently discarded before; a
    // failing first drop would have gone unnoticed.
    assert(res_drop.ok());
    res_drop = client.dropIndex(tableParam);
    assert(res_drop.ok());
    DescribeIndexResponse res = client.describeIndex(tableParam);
    assert(res.getResponse().ok());
    Index index1 = res.getIndex().get();
    Assert.assertEquals(index1.getNList(), default_n_list);
    Assert.assertEquals(index1.getIndexType(), defaultIndexType);
}
// dropIndex against a non-existent table name must fail.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_drop_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
String tableNameNew = tableName + "_";
TableParam tableParam = new TableParam.Builder(tableNameNew).build();
Response res_drop = client.dropIndex(tableParam);
assert(!res_drop.ok());
}
// dropIndex on a disconnected client must fail.
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_drop_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
TableParam tableParam = new TableParam.Builder(tableName).build();
Response res_drop = client.dropIndex(tableParam);
assert(!res_drop.ok());
}
// dropIndex on a table that never had an index must still succeed, and
// describeIndex continues to report the server defaults.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_drop_index_no_index_created(MilvusClient client, String tableName) throws InterruptedException {
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
TableParam tableParam = new TableParam.Builder(tableName).build();
Response res_drop = client.dropIndex(tableParam);
assert(res_drop.ok());
DescribeIndexResponse res = client.describeIndex(tableParam);
assert(res.getResponse().ok());
Index index1 = res.getIndex().get();
Assert.assertEquals(index1.getNList(), default_n_list);
Assert.assertEquals(index1.getIndexType(), defaultIndexType);
}
}

View File

@ -0,0 +1,221 @@
package com;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
// Concurrency / stress tests: each test runs many client operations in parallel
// on a ForkJoinPool and then checks the aggregate outcome. Code is kept
// byte-identical; only documentation is added because the scenarios depend on
// exact operation interleaving against a live Milvus server.
public class TestMix {
// Vector dimensionality used by gen_vectors.
private int dimension = 128;
// Number of vectors inserted per operation.
private int nb = 100000;
int n_list = 8192;
int n_probe = 20;
int top_k = 10;
// Tolerance for floating-point distance comparisons.
double epsilon = 0.001;
int index_file_size = 20;
// L2-normalizes a vector in sequential (left-to-right) float arithmetic.
public List<Float> normalize(List<Float> w2v){
float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
final float norm = (float) Math.sqrt(squareSum);
w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
return w2v;
}
// Generates `nb` random vectors of `dimension` floats; normalizes each when
// `norm` is true (needed for IP-metric distance assertions).
public List<List<Float>> gen_vectors(int nb, boolean norm) {
List<List<Float>> xb = new ArrayList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
List<Float> vector = new ArrayList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
if (norm == true) {
vector = normalize(vector);
}
xb.add(vector);
}
return xb;
}
// 10 threads search the same table concurrently after an IVF_SQ8 index is built.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
int thread_num = 10;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
});
}
executor.awaitQuiescence(100, TimeUnit.SECONDS);
executor.shutdown();
}
// 100 threads each open their own connection, verify it, and disconnect.
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_connect_threads(String host, String port) throws InterruptedException {
int thread_num = 100;
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
assert(client.connected());
try {
client.disconnect();
} catch (InterruptedException e) {
e.printStackTrace();
}
});
}
executor.awaitQuiescence(100, TimeUnit.SECONDS);
executor.shutdown();
}
// 10 threads insert the same batch concurrently; the final row count must be
// thread_num * nb (every insert must have landed).
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
int thread_num = 10;
List<List<Float>> vectors = gen_vectors(nb,false);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
InsertResponse res_insert = client.insert(insertParam);
assert (res_insert.getResponse().ok());
});
}
executor.awaitQuiescence(100, TimeUnit.SECONDS);
executor.shutdown();
// Give the server time to flush before counting rows.
Thread.sleep(2000);
TableParam tableParam = new TableParam.Builder(tableName).build();
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
}
// 50 threads interleave inserts with createIndex calls; every insert must land.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
int thread_num = 50;
List<List<Float>> vectors = gen_vectors(nb,false);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
InsertResponse res_insert = client.insert(insertParam);
Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
assert (res_insert.getResponse().ok());
});
}
executor.awaitQuiescence(300, TimeUnit.SECONDS);
executor.shutdown();
Thread.sleep(2000);
TableParam tableParam = new TableParam.Builder(tableName).build();
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
}
// 50 threads insert then immediately search; with normalized vectors the top
// hit of the first query should be itself (distance 0 for L2, 1 for IP).
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
int thread_num = 50;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, true);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
InsertResponse res_insert = client.insert(insertParam);
assert (res_insert.getResponse().ok());
try {
TimeUnit.SECONDS.sleep(1);
} catch (InterruptedException e) {
e.printStackTrace();
}
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
double distance = res.get(0).get(0).getDistance();
// The table-name prefix encodes the metric type used by the fixture.
if (tableName.startsWith("L2")) {
Assert.assertEquals(distance, 0.0, epsilon);
}else if (tableName.startsWith("IP")) {
Assert.assertEquals(distance, 1.0, epsilon);
}
});
}
executor.awaitQuiescence(300, TimeUnit.SECONDS);
executor.shutdown();
Thread.sleep(2000);
TableParam tableParam = new TableParam.Builder(tableName).build();
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
}
// 100 threads each run a full lifecycle: connect, create a random table,
// insert, drop the table, disconnect.
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_create_insert_delete_threads(String host, String port) throws InterruptedException {
int thread_num = 100;
List<List<Float>> vectors = gen_vectors(nb,false);
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
String tableName = RandomStringUtils.randomAlphabetic(10);
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.IP)
.build();
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
client.createTable(tableSchemaParam);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
TableParam tableParam = new TableParam.Builder(tableName).build();
Response response = client.dropTable(tableParam);
Assert.assertTrue(response.ok());
try {
client.disconnect();
} catch (InterruptedException e) {
e.printStackTrace();
}
});
}
executor.awaitQuiescence(100, TimeUnit.SECONDS);
executor.shutdown();
}
}

View File

@ -0,0 +1,28 @@
package com;
import io.milvus.client.ConnectParam;
import io.milvus.client.MilvusClient;
import io.milvus.client.MilvusGrpcClient;
import io.milvus.client.Response;
import org.testng.annotations.Test;
// Smoke tests for server reachability via serverStatus().
public class TestPing {

    // Connects with the supplied host/port and expects serverStatus() to succeed.
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_server_status(String host, String port){
        System.out.println("Host: "+host+", Port: "+port);
        ConnectParam params = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        MilvusClient milvusClient = new MilvusGrpcClient();
        milvusClient.connect(params);
        Response statusResponse = milvusClient.serverStatus();
        assert (statusResponse.ok());
    }

    // Without an active connection, serverStatus() must report a failure.
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_server_status_without_connected(MilvusGrpcClient client, String tableName){
        Response statusResponse = client.serverStatus();
        assert (!statusResponse.ok());
    }
}

View File

@ -0,0 +1,480 @@
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TestSearchVectors {
// Shared fixture parameters for the search tests below.
int index_file_size = 10;
// Vector dimensionality used by gen_vectors.
int dimension = 128;
// nlist used when building indexes in most tests.
int n_list = 1024;
// Server-side default nlist, reported when no index is configured.
int default_n_list = 16384;
// Default number of vectors inserted per test.
int nb = 100000;
int n_probe = 20;
int top_k = 10;
// Tolerance for floating-point distance comparisons.
double epsilon = 0.001;
// Index type used by tests that don't override it.
IndexType indexType = IndexType.IVF_SQ8;
// Server-side default index type.
IndexType defaultIndexType = IndexType.FLAT;
// Returns a new list holding the L2-normalized copy of `w2v`.
// Accumulation is sequential left-to-right, matching the original stream-based form.
public List<Float> normalize(List<Float> w2v){
    // Sum of squares gives the squared L2 norm.
    float sumOfSquares = 0f;
    for (float component : w2v) {
        sumOfSquares += component * component;
    }
    final float l2Norm = (float) Math.sqrt(sumOfSquares);
    // Scale every component so the resulting vector has unit length.
    List<Float> unitVector = new ArrayList<>();
    for (float component : w2v) {
        unitVector.add(component / l2Norm);
    }
    return unitVector;
}
// Generates `nb` random vectors of `dimension` floats each; when `norm` is true
// every vector is L2-normalized (required by the IP-metric distance assertions).
public List<List<Float>> gen_vectors(int nb, boolean norm) {
    Random rng = new Random();
    List<List<Float>> vectors = new ArrayList<>();
    for (int row = 0; row < nb; ++row) {
        List<Float> vec = new ArrayList<>();
        for (int col = 0; col < dimension; ++col) {
            vec.add(rng.nextFloat());
        }
        vectors.add(norm ? normalize(vec) : vec);
    }
    return vectors;
}
// Returns the current date/time shifted by `delta` calendar days
// (negative values shift into the past).
public static Date getDeltaDate(int delta) {
    Calendar calendar = Calendar.getInstance();
    calendar.setTime(new Date());
    calendar.add(Calendar.DAY_OF_MONTH, delta);
    return calendar.getTime();
}
// Searching a non-existent table name must fail.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
String tableNameNew = tableName + "_";
int nq = 5;
int nb = 100;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
SearchParam searchParam = new SearchParam.Builder(tableNameNew, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (!res_search.getResponse().ok());
}
// Search through an IVFLAT index: expect nq result lists of top_k hits each.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVFLAT;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
// NOTE(review): tableParam below is unused — candidate for removal.
TableParam tableParam = new TableParam.Builder(tableName).build();
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
Assert.assertEquals(res_search.size(), nq);
Assert.assertEquals(res_search.get(0).size(), top_k);
}
// Inserts vectors with explicit ids and checks that the top hit for the first
// query vector (vectors[0], searched against itself) is the vector with id 0.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_ids_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
    IndexType indexType = IndexType.IVFLAT;
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, true);
    List<List<Float>> queryVectors = vectors.subList(0,nq);
    // BUG FIX: Stream.iterate(0L, n -> n) assigned id 0 to EVERY vector, which
    // made the final assertion pass trivially; generate distinct ids 0,1,2,...
    List<Long> vectorIds = Stream.iterate(0L, n -> n + 1)
            .limit(nb)
            .collect(Collectors.toList());
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
    List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
    // The nearest neighbour of vectors[0] is itself, i.e. the vector inserted with id 0.
    Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L);
}
// Brute-force search (no index is built despite the local indexType variable):
// expect nq result lists of top_k hits after the flush sleep.
// NOTE(review): the local `indexType` is unused — candidate for removal.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVFLAT;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Thread.sleep(1000);
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
Assert.assertEquals(res_search.size(), nq);
Assert.assertEquals(res_search.get(0).size(), top_k);
}
// With normalized vectors the top hit of a query against itself should have
// distance 0 under L2 or 1 under IP (metric encoded in the table-name prefix).
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_distance_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVFLAT;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, true);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
TableParam tableParam = new TableParam.Builder(tableName).build();
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
double distance = res_search.get(0).get(0).getDistance();
if (tableName.startsWith("L2")) {
Assert.assertEquals(distance, 0.0, epsilon);
}else if (tableName.startsWith("IP")) {
Assert.assertEquals(distance, 1.0, epsilon);
}
}
// Search through an IVF_SQ8 index: expect nq result lists of top_k hits each.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVF_SQ8;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
TableParam tableParam = new TableParam.Builder(tableName).build();
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
Assert.assertEquals(res_search.size(), nq);
Assert.assertEquals(res_search.get(0).size(), top_k);
}
// Brute-force search (no index built; local indexType unused): expect nq
// result lists of top_k hits after the flush sleep.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVF_SQ8;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Thread.sleep(1000);
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
Assert.assertEquals(res_search.size(), nq);
Assert.assertEquals(res_search.get(0).size(), top_k);
}
// Self-search distances through an IVF_SQ8 index must be ~0 (L2) or ~1 (IP)
// for every query; reads distances via getResultDistancesList.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_distance_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.IVF_SQ8;
int nq = 5;
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb, true);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(default_n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
TableParam tableParam = new TableParam.Builder(tableName).build();
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<Double>> res_search = client.search(searchParam).getResultDistancesList();
for (int i = 0; i < nq; i++) {
double distance = res_search.get(i).get(0);
System.out.println(distance);
if (tableName.startsWith("L2")) {
Assert.assertEquals(distance, 0.0, epsilon);
}else if (tableName.startsWith("IP")) {
Assert.assertEquals(distance, 1.0, epsilon);
}
}
}
// Search through a FLAT index: expect nq result lists of top_k hits each.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_index_FLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.FLAT;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(indexType)
.withNList(n_list)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
TableParam tableParam = new TableParam.Builder(tableName).build();
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
Assert.assertEquals(res_search.size(), nq);
Assert.assertEquals(res_search.get(0).size(), top_k);
}
// Brute-force search with no index built: expect nq result lists of top_k hits.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_FLAT(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.FLAT;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Thread.sleep(1000);
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
Assert.assertEquals(res_search.size(), nq);
Assert.assertEquals(res_search.get(0).size(), top_k);
}
// A large brute-force search (1000 queries, top_k 2048) with a 1-second client
// timeout is expected to fail due to the timeout.
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.FLAT;
int nb = 100000;
int nq = 1000;
int top_k = 2048;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Thread.sleep(1000);
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withTimeout(1).build();
System.out.println(new Date());
SearchResponse res_search = client.search(searchParam);
assert (!res_search.getResponse().ok());
}
// The same large brute-force search without a timeout must succeed (exercises
// a big request/response payload: 2000 queries x top_k 2048).
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_FLAT_big_data_size(MilvusClient client, String tableName) throws InterruptedException {
IndexType indexType = IndexType.FLAT;
int nb = 100000;
int nq = 2000;
int top_k = 2048;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.insert(insertParam);
Thread.sleep(1000);
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
System.out.println(new Date());
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_distance_FLAT(MilvusClient client, String tableName) throws InterruptedException {
    // With a FLAT index, searching for a vector already in the table must
    // return distance 0.0 under L2 and 1.0 under IP (normalized vectors).
    IndexType indexType = IndexType.FLAT;
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, true);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
    List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
    double distance = res_search.get(0).get(0).getDistance();
    if (tableName.startsWith("L2")) {
        Assert.assertEquals(distance, 0.0, epsilon);
    } else if (tableName.startsWith("IP")) {
        Assert.assertEquals(distance, 1.0, epsilon);
    } else {
        // Previously an unknown prefix made the test pass without asserting anything.
        Assert.fail("Unexpected table name prefix (expected L2 or IP): " + tableName);
    }
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_invalid_n_probe(MilvusClient client, String tableName) throws InterruptedException {
    // nprobe = 0 is invalid and the search must fail.
    // (Removed an unused local TableParam.)
    IndexType indexType = IndexType.IVF_SQ8;
    int nq = 5;
    int n_probe_new = 0;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe_new).withTopK(top_k).build();
    SearchResponse res_search = client.search(searchParam);
    assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_invalid_top_k(MilvusClient client, String tableName) throws InterruptedException {
    // top_k = 0 is invalid and the search must fail.
    // (Removed an unused local TableParam.)
    IndexType indexType = IndexType.IVF_SQ8;
    int nq = 5;
    int top_k_new = 0;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k_new).build();
    SearchResponse res_search = client.search(searchParam);
    assert (!res_search.getResponse().ok());
}
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_search_invalid_query_vectors(MilvusClient client, String tableName) throws InterruptedException {
// IndexType indexType = IndexType.IVF_SQ8;
// int nq = 5;
// List<List<Float>> vectors = gen_vectors(nb, false);
// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
// client.insert(insertParam);
// Index index = new Index.Builder().withIndexType(indexType)
// .withNList(n_list)
// .build();
// CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
// client.createIndex(createIndexParam);
// TableParam tableParam = new TableParam.Builder(tableName).build();
// SearchParam searchParam = new SearchParam.Builder(tableName, null).withNProbe(n_probe).withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (!res_search.getResponse().ok());
// }
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_index_range(MilvusClient client, String tableName) throws InterruptedException {
    // Search with a date range of [yesterday, tomorrow]; today's inserts fall
    // inside the range so full top-k results are expected.
    // (Removed an unused local TableParam and a redundant second search call.)
    IndexType indexType = IndexType.IVF_SQ8;
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    List<DateRange> dateRange = new ArrayList<>();
    dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1)));
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
    SearchResponse res_search = client.search(searchParam);
    assert (res_search.getResponse().ok());
    // Reuse the response instead of issuing a second identical search.
    List<List<SearchResponse.QueryResult>> res = res_search.getQueryResultsList();
    Assert.assertEquals(res.size(), nq);
    Assert.assertEquals(res.get(0).size(), top_k);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_range(MilvusClient client, String tableName) throws InterruptedException {
    // Same range search as test_search_index_range but without an index.
    // (Removed a redundant second search call.)
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    List<DateRange> dateRange = new ArrayList<>();
    dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1)));
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Thread.sleep(1000); // wait for inserted vectors to become searchable
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
    SearchResponse res_search = client.search(searchParam);
    assert (res_search.getResponse().ok());
    List<List<SearchResponse.QueryResult>> res = res_search.getQueryResultsList();
    Assert.assertEquals(res.size(), nq);
    Assert.assertEquals(res.get(0).size(), top_k);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_index_range_no_result(MilvusClient client, String tableName) throws InterruptedException {
    // A date range entirely in the past excludes today's inserts, so the
    // result set must be empty.
    // (Removed an unused local TableParam and a redundant second search call.)
    IndexType indexType = IndexType.IVF_SQ8;
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    List<DateRange> dateRange = new ArrayList<>();
    dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1)));
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
    SearchResponse res_search = client.search(searchParam);
    assert (res_search.getResponse().ok());
    List<List<SearchResponse.QueryResult>> res = res_search.getQueryResultsList();
    Assert.assertEquals(res.size(), 0);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_range_no_result(MilvusClient client, String tableName) throws InterruptedException {
    // A date range entirely in the past must yield an empty result set.
    // (Removed a redundant second search call.)
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    List<DateRange> dateRange = new ArrayList<>();
    dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1)));
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Thread.sleep(1000); // wait for inserted vectors to become searchable
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
    SearchResponse res_search = client.search(searchParam);
    assert (res_search.getResponse().ok());
    List<List<SearchResponse.QueryResult>> res = res_search.getQueryResultsList();
    Assert.assertEquals(res.size(), 0);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_index_range_invalid(MilvusClient client, String tableName) throws InterruptedException {
    // A date range whose start is after its end must be rejected.
    // (Removed an unused local TableParam.)
    IndexType indexType = IndexType.IVF_SQ8;
    int nq = 5;
    List<List<Float>> vectors = gen_vectors(nb, false);
    List<List<Float>> queryVectors = vectors.subList(0, nq);
    List<DateRange> dateRange = new ArrayList<>();
    dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1)));
    InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
    client.insert(insertParam);
    Index index = new Index.Builder().withIndexType(indexType)
            .withNList(n_list)
            .build();
    CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
    client.createIndex(createIndexParam);
    SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
    SearchResponse res_search = client.search(searchParam);
    assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_range_invalid(MilvusClient client, String tableName) throws InterruptedException {
    // A date range whose start is after its end must be rejected by the server.
    int nq = 5;
    List<DateRange> invalidRanges = new ArrayList<>();
    invalidRanges.add(new DateRange(getDeltaDate(2), getDeltaDate(-1)));
    List<List<Float>> allVectors = gen_vectors(nb, false);
    client.insert(new InsertParam.Builder(tableName, allVectors).build());
    Thread.sleep(1000);
    SearchParam searchParam = new SearchParam.Builder(tableName, allVectors.subList(0, nq))
            .withNProbe(n_probe)
            .withTopK(top_k)
            .withDateRanges(invalidRanges)
            .build();
    assert (!client.search(searchParam).getResponse().ok());
}
}

View File

@ -0,0 +1,155 @@
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
public class TestTable {
    // Shared schema parameters used by every test in this class.
    int index_file_size = 50;
    int dimension = 128;

    // Builds the default schema-creation parameter (L2 metric, shared
    // dimension and index file size) used by most tests below.
    private TableSchemaParam defaultSchemaParam(String tableName) {
        TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                .withIndexFileSize(index_file_size)
                .withMetricType(MetricType.L2)
                .build();
        return new TableSchemaParam.Builder(tableSchema).build();
    }

    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void test_create_table(MilvusClient client, String tableName) {
        // Creating a fresh table must succeed.
        // (Collapsed a duplicated assertion on res.ok().)
        Response res = client.createTable(defaultSchemaParam(tableName));
        Assert.assertTrue(res.ok());
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_create_table_disconnect(MilvusClient client, String tableName) {
        // Creating a table on a disconnected client must fail.
        Response res = client.createTable(defaultSchemaParam(tableName));
        assert (!res.ok());
    }

    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void test_create_table_repeatably(MilvusClient client, String tableName) {
        // Creating the same table twice must fail the second time.
        TableSchemaParam tableSchemaParam = defaultSchemaParam(tableName);
        Response res = client.createTable(tableSchemaParam);
        Assert.assertEquals(res.ok(), true);
        Response res_new = client.createTable(tableSchemaParam);
        Assert.assertEquals(res_new.ok(), false);
    }

    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void test_create_table_wrong_params(MilvusClient client, String tableName) {
        // Dimension 0 is invalid; the server must reject the schema.
        Integer dimension = 0;
        TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                .withIndexFileSize(index_file_size)
                .withMetricType(MetricType.L2)
                .build();
        TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
        Response res = client.createTable(tableSchemaParam);
        System.out.println(res.toString());
        Assert.assertEquals(res.ok(), false);
    }

    @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
    public void test_show_tables(MilvusClient client, String tableName) {
        // Every newly created table must appear in showTables().
        // (Removed an unused ShowTablesResponse local.)
        Integer tableNum = 10;
        for (int i = 0; i < tableNum; ++i) {
            String tableNameNew = tableName + "_" + Integer.toString(i);
            client.createTable(defaultSchemaParam(tableNameNew));
            List<String> tableNames = client.showTables().getTableNames();
            Assert.assertTrue(tableNames.contains(tableNameNew));
        }
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_show_tables_without_connect(MilvusClient client, String tableName) {
        ShowTablesResponse res = client.showTables();
        assert (!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_drop_table(MilvusClient client, String tableName) throws InterruptedException {
        // A dropped table must disappear from showTables().
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Response res = client.dropTable(tableParam);
        assert (res.ok());
        Thread.sleep(1000); // Thread.sleep is static; avoid calling it on an instance
        List<String> tableNames = client.showTables().getTableNames();
        Assert.assertFalse(tableNames.contains(tableName));
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_drop_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
        // Dropping a non-existent table fails and leaves the real table intact.
        TableParam tableParam = new TableParam.Builder(tableName + "_").build();
        Response res = client.dropTable(tableParam);
        assert (!res.ok());
        List<String> tableNames = client.showTables().getTableNames();
        Assert.assertTrue(tableNames.contains(tableName));
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_drop_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Response res = client.dropTable(tableParam);
        assert (!res.ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_describe_table(MilvusClient client, String tableName) throws InterruptedException {
        // The schema returned must echo the creation parameters; the metric
        // type is encoded in the table-name prefix by the data provider.
        TableParam tableParam = new TableParam.Builder(tableName).build();
        DescribeTableResponse res = client.describeTable(tableParam);
        assert (res.getResponse().ok());
        TableSchema tableSchema = res.getTableSchema().get();
        Assert.assertEquals(tableSchema.getDimension(), dimension);
        Assert.assertEquals(tableSchema.getTableName(), tableName);
        Assert.assertEquals(tableSchema.getIndexFileSize(), index_file_size);
        Assert.assertEquals(tableSchema.getMetricType().name(), tableName.substring(0, 2));
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_describe_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        TableParam tableParam = new TableParam.Builder(tableName).build();
        DescribeTableResponse res = client.describeTable(tableParam);
        assert (!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_has_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
        // hasTable succeeds as a call but reports the table as absent.
        TableParam tableParam = new TableParam.Builder(tableName + "_").build();
        HasTableResponse res = client.hasTable(tableParam);
        assert (res.getResponse().ok());
        Assert.assertFalse(res.hasTable());
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_has_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        TableParam tableParam = new TableParam.Builder(tableName).build();
        HasTableResponse res = client.hasTable(tableParam);
        assert (!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_has_table(MilvusClient client, String tableName) throws InterruptedException {
        TableParam tableParam = new TableParam.Builder(tableName).build();
        HasTableResponse res = client.hasTable(tableParam);
        assert (res.getResponse().ok());
        Assert.assertTrue(res.hasTable());
    }
}

View File

@ -0,0 +1,89 @@
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class TestTableCount {
    // Shared schema parameters used by every test in this class.
    int index_file_size = 50;
    int dimension = 128;

    // Generates nb random float vectors of the shared dimension.
    public List<List<Float>> gen_vectors(Integer nb) {
        List<List<Float>> xb = new ArrayList<>();
        Random random = new Random();
        for (int i = 0; i < nb; ++i) {
            ArrayList<Float> vector = new ArrayList<>();
            for (int j = 0; j < dimension; j++) {
                vector.add(random.nextFloat());
            }
            xb.add(vector);
        }
        return xb;
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_table_count_no_vectors(MilvusClient client, String tableName) {
        // A freshly created table holds zero rows.
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0);
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_table_count_table_not_existed(MilvusClient client, String tableName) {
        // Row count of a non-existent table must report an error.
        TableParam tableParam = new TableParam.Builder(tableName + "_").build();
        GetTableRowCountResponse res = client.getTableRowCount(tableParam);
        assert (!res.getResponse().ok());
    }

    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_table_count_without_connect(MilvusClient client, String tableName) {
        // Any request on a disconnected client must fail.
        TableParam tableParam = new TableParam.Builder(tableName + "_").build();
        GetTableRowCountResponse res = client.getTableRowCount(tableParam);
        assert (!res.getResponse().ok());
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_table_count(MilvusClient client, String tableName) throws InterruptedException {
        // The row count after an insert must match the number of vectors.
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        // (Fixed a stray double semicolon on the builder call.)
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        client.insert(insertParam);
        Thread.sleep(1000); // allow the insert to be flushed before counting
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
    }

    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_table_count_multi_tables(MilvusClient client, String tableName) throws InterruptedException {
        // Inserts into several tables must be counted independently.
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        Integer tableNum = 10;
        for (int i = 0; i < tableNum; ++i) {
            String tableNameNew = tableName + "_" + Integer.toString(i);
            TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension)
                    .withIndexFileSize(index_file_size)
                    .withMetricType(MetricType.L2)
                    .build();
            TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
            client.createTable(tableSchemaParam);
            client.insert(new InsertParam.Builder(tableNameNew, vectors).build());
        }
        Thread.sleep(1000); // Thread.sleep is static; avoid calling it on an instance
        for (int i = 0; i < tableNum; ++i) {
            String tableNameNew = tableName + "_" + Integer.toString(i);
            TableParam tableParam = new TableParam.Builder(tableNameNew).build();
            GetTableRowCountResponse res = client.getTableRowCount(tableParam);
            Assert.assertEquals(res.getTableRowCount(), nb);
        }
    }
}

View File

@ -0,0 +1,8 @@
<suite name="Test-class Suite">
<test name="Test-class test" >
<classes>
<class name="com.TestConnect" />
<class name="com.TestMix" />
</classes>
</test>
</suite>

2
tests/milvus_ann_acc/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
__pycache__/
logs/

View File

@ -0,0 +1,149 @@
import pdb
import random
import logging
import json
import time, datetime
from multiprocessing import Process
import numpy
import sklearn.preprocessing
from milvus import Milvus, IndexType, MetricType
logger = logging.getLogger("milvus_ann_acc.client")
SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530
def time_wrapper(func):
    """Decorator logging the wall-clock duration of every call to *func*."""
    def wrapper(*args, **kwargs):
        begin = time.time()
        ret = func(*args, **kwargs)
        elapsed = round(time.time() - begin, 2)
        logger.info("Milvus {} run in {}s".format(func.__name__, elapsed))
        return ret
    return wrapper
class MilvusClient(object):
    """Thin wrapper around the Milvus Python SDK used by the accuracy runner.

    Holds one connection and an optional default table name. For IP-metric
    tables, vectors are L2-normalized before insert and search; note that
    ``_metric_type`` is only set by :meth:`create_table`, so that method must
    be called before :meth:`insert` / :meth:`query` on a fresh instance.
    """

    def __init__(self, table_name=None, ip=None, port=None):
        """Connect to a Milvus server (defaults to 127.0.0.1:19530).

        The original wrapped connect in ``try/except ... raise e``, which only
        re-raised; connection errors now propagate directly.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        if not ip:
            self._milvus.connect(host=SERVER_HOST_DEFAULT, port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(host=ip, port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Raise if a Milvus status is not OK, logging its message first."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create a table and record it as this client's default.

        ``metric_type`` must be "l2" or "ip"; the original only logged an
        unsupported value and then sent the raw string to the server — it is
        now rejected with ValueError.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
            raise ValueError("Not supported metric_type: %s" % metric_type)
        self._metric_type = metric_type
        create_param = {'table_name': table_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids):
        """Insert a numpy array of vectors with explicit ids.

        IP-metric data is L2-normalized first, matching query().
        """
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        X = X.astype(numpy.float32)
        status, result = self._milvus.add_vectors(self._table_name, X.tolist(), ids=ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; index_type is one of flat / ivf_flat / ivf_sq8 /
        ivf_sq8h / mix_nsg. Unknown names are rejected (the original silently
        forwarded the raw string to the server)."""
        index_type_map = {
            "flat": IndexType.FLAT,
            "ivf_flat": IndexType.IVFLAT,
            "ivf_sq8": IndexType.IVF_SQ8,
            "ivf_sq8h": IndexType.IVF_SQ8H,
            "mix_nsg": IndexType.MIX_NSG,
        }
        if index_type not in index_type_map:
            raise ValueError("Not supported index_type: %s" % index_type)
        index_params = {
            "index_type": index_type_map[index_type],
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
        self.check_status(status)

    def describe_index(self):
        """Return the server's index description for the default table."""
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        """Drop the index on the default table."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search X against the default table; returns one top-k id list per query."""
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        X = X.astype(numpy.float32)
        status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, X.tolist())
        self.check_status(status)
        ids = []
        for result in results:
            tmp_ids = []
            for item in result:
                tmp_ids.append(item.id)
            ids.append(tmp_ids)
        return ids

    def count(self):
        """Row count of the default table."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Drop the default table, polling up to *timeout* seconds until empty."""
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout:
            if self.count():
                time.sleep(1)
                i = i + 1
            else:
                break
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        """Return the server's schema description for the default table."""
        return self._milvus.describe_table(self._table_name)

    def exists_table(self):
        """True-ish result when the default table exists on the server."""
        return self._milvus.has_table(self._table_name)

    @time_wrapper
    def preload_table(self):
        """Ask the server to load the default table into memory."""
        return self._milvus.preload_table(self._table_name)

View File

@ -0,0 +1,17 @@
datasets:
sift-128-euclidean:
cpu_cache_size: 16
gpu_cache_size: 5
index_file_size: [1024]
nytimes-16-angular:
cpu_cache_size: 16
gpu_cache_size: 5
index_file_size: [1024]
index:
index_types: ['flat', 'ivf_flat', 'ivf_sq8']
nlists: [8092, 16384]
search:
nprobes: [1, 8, 32]
top_ks: [10]

View File

@ -0,0 +1,26 @@
import argparse
def main():
    # Builds the CLI for the benchmark runner.
    # NOTE(review): DATASETS and positive_int are not defined or imported in
    # this file, and the parsed arguments are never used — the file appears
    # incomplete; confirm before relying on this entry point.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--dataset',
        metavar='NAME',
        help='the dataset to load training points from',
        default='glove-100-angular',
        choices=DATASETS.keys())
    parser.add_argument(
        "-k", "--count",
        default=10,
        type=positive_int,
        help="the number of near neighbours to search for")
    parser.add_argument(
        '--definitions',
        metavar='FILE',
        help='load algorithm definitions from FILE',
        default='algos.yaml')
    parser.add_argument(
        '--image-tag',
        default=None,
        help='pull image first')

View File

@ -0,0 +1,132 @@
import os
import pdb
import time
import random
import sys
import h5py
import numpy
import logging
from logging import handlers
from client import MilvusClient
# Logging setup: a daily-rotating file under logs/ (kept for 10 days) plus
# DEBUG-level console output via basicConfig.
LOG_FOLDER = "logs"
logger = logging.getLogger("milvus_ann_acc")
formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
if not os.path.exists(LOG_FOLDER):
    os.system('mkdir -p %s' % LOG_FOLDER)
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10)
fileTimeHandler.suffix = "%Y%m%d.log"
fileTimeHandler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG)
fileTimeHandler.setFormatter(formatter)  # NOTE(review): duplicate of the call above; harmless but redundant
logger.addHandler(fileTimeHandler)
def get_dataset_fn(dataset_name):
    """Return the path of the dataset's HDF5 file inside the fixed data
    directory, raising if that directory does not exist."""
    base_dir = "/test/milvus/ann_hdf5/"
    if not os.path.exists(base_dir):
        raise Exception("%s not exists" % base_dir)
    return os.path.join(base_dir, '%s.hdf5' % dataset_name)
def get_dataset(dataset_name):
    """Open the dataset's HDF5 file read-only and return the h5py File."""
    hdf5_fn = get_dataset_fn(dataset_name)
    # Explicit mode: h5py.File without a mode is deprecated and its historical
    # default could open the file writable / create it.
    hdf5_f = h5py.File(hdf5_fn, "r")
    return hdf5_f
def parse_dataset_name(dataset_name):
    """Split an ann-benchmarks dataset name like "sift-128-euclidean" into
    ("ann" + data type, dimension, metric string).

    Raises ValueError for an unknown metric; previously an unknown metric fell
    through and triggered UnboundLocalError on metric_type.
    """
    data_type = dataset_name.split("-")[0]
    dimension = int(dataset_name.split("-")[1])
    metric = dataset_name.split("-")[-1]
    if metric == "euclidean":
        metric_type = "l2"
    elif metric == "angular":
        metric_type = "ip"
    else:
        raise ValueError("Not supported metric: %s" % metric)
    return ("ann" + data_type, dimension, metric_type)
def get_table_name(dataset_name, index_file_size):
    """Derive the Milvus table name: <type>_<size-in-millions>m_<ifs>_<dim>_<metric>."""
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    train_rows = len(get_dataset(dataset_name)["train"])
    size_label = str(train_rows // 1000000) + "m"
    parts = [data_type, size_label, str(index_file_size), str(dimension), metric_type]
    return "_".join(parts)
def main(dataset_name, index_file_size, nlist=16384, force=False):
    # End-to-end accuracy pass for one dataset / index_file_size combination:
    # (re)create the table, bulk-insert the train split, then for each index
    # type build the index and measure top-k recall against "neighbors".
    top_k = 10
    nprobes = [32, 128]
    dataset = get_dataset(dataset_name)
    table_name = get_table_name(dataset_name, index_file_size)
    m = MilvusClient(table_name)
    if m.exists_table():
        if force is True:
            logger.info("Re-create table: %s" % table_name)
            m.delete()
            time.sleep(10)  # give the server time to finish the drop
        else:
            logger.info("Table name: %s existed" % table_name)
            return
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    m.create_table(table_name, dimension, index_file_size, metric_type)
    print(m.describe())
    vectors = numpy.array(dataset["train"])
    query_vectors = numpy.array(dataset["test"])
    # m.insert(vectors)
    # Insert in chunks of 100k rows; ids are the row positions so they line up
    # with the ground-truth "neighbors" indices.
    interval = 100000
    loops = len(vectors) // interval + 1
    for i in range(loops):
        start = i*interval
        end = min((i+1)*interval, len(vectors))
        tmp_vectors = vectors[start:end]
        if start < end:
            m.insert(tmp_vectors, ids=[i for i in range(start, end)])
    time.sleep(60)  # wait for all inserts to be flushed before counting
    print(m.count())
    for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]:
        m.create_index(index_type, nlist)
        print(m.describe_index())
        if m.count() != len(vectors):
            # Not every vector is searchable yet — abort rather than report
            # misleading recall numbers.
            return
        m.preload_table()
        true_ids = numpy.array(dataset["neighbors"])
        for nprobe in nprobes:
            print("nprobe: %s" % nprobe)
            # Recall: mean over queries of |true top-k ∩ found top-k| / top_k.
            sum_radio = 0.0; avg_radio = 0.0
            result_ids = m.query(query_vectors, top_k, nprobe)
            # print(result_ids[:10])
            for index, result_item in enumerate(result_ids):
                if len(set(true_ids[index][:top_k])) != len(set(result_item)):
                    logger.info("Error happened")
                    # logger.info(query_vectors[index])
                    # logger.info(true_ids[index][:top_k], result_item)
                tmp = set(true_ids[index][:top_k]).intersection(set(result_item))
                sum_radio = sum_radio + (len(tmp) / top_k)
            avg_radio = round(sum_radio / len(result_ids), 4)
            logger.info(avg_radio)
        m.drop_index()
if __name__ == "__main__":
    # Run accuracy checks for both datasets at two index file sizes,
    # force-recreating the tables each time.
    print("glove-25-angular")
    # main("sift-128-euclidean", 1024, force=True)
    for index_file_size in [50, 1024]:
        print("Index file size: %d" % index_file_size)
        main("glove-25-angular", index_file_size, force=True)
    print("sift-128-euclidean")
    for index_file_size in [50, 1024]:
        print("Index file size: %d" % index_file_size)
        main("sift-128-euclidean", index_file_size, force=True)
    # m = MilvusClient()

8
tests/milvus_benchmark/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
random_data
benchmark_logs/
db/
logs/
*idmap*.txt
__pycache__/
venv
.idea

View File

@ -0,0 +1,57 @@
# Quick start
## 运行
### 运行示例:
`python3 main.py --image=registry.zilliz.com/milvus/engine:branch-0.3.1-release --run-count=2 --run-type=performance`
### 运行参数:
--image: 容器模式传入镜像名称如传入则运行测试时会先进行pull image基于image生成milvus server容器
--local: 与image参数互斥本地模式连接使用本地启动的milvus server进行测试
--run-count: 重复运行次数
--suites: 测试集配置文件默认使用suites.yaml
--run-type: 测试类型,包括性能--performance、准确性测试--accuracy以及稳定性--stability
### 测试集配置文件:
`operations:
insert:
[
{"table.index_type": "ivf_flat", "server.index_building_threshold": 300, "table.size": 2000000, "table.ni": 100000, "table.dim": 512},
]
build: []
query:
[
{"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 1, "server.use_blas_threshold": 800},
{"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 10, "server.use_blas_threshold": 20},
]`
## 测试结果:
性能:
`INFO:milvus_benchmark.runner:Start warm query, query params: top-k: 1, nq: 1
INFO:milvus_benchmark.client:query run in 19.19s
INFO:milvus_benchmark.runner:Start query, query params: top-k: 64, nq: 10, actually length of vectors: 10
INFO:milvus_benchmark.runner:Start run query, run 1 of 1
INFO:milvus_benchmark.client:query run in 0.2s
INFO:milvus_benchmark.runner:Avarage query time: 0.20
INFO:milvus_benchmark.runner:[[0.2]]`
**│ 10 │ 0.2 │**
准确率:
`INFO:milvus_benchmark.runner:Avarage accuracy: 1.0`

View File

View File

@ -0,0 +1,244 @@
import pdb
import random
import logging
import json
import sys
import time, datetime
from multiprocessing import Process
from milvus import Milvus, IndexType, MetricType
logger = logging.getLogger("milvus_benchmark.client")
SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530
def time_wrapper(func):
    """Decorator logging the wall-clock duration of every call to *func*."""
    def wrapper(*args, **kwargs):
        begin = time.time()
        ret = func(*args, **kwargs)
        elapsed = round(time.time() - begin, 2)
        logger.info("Milvus {} run in {}s".format(func.__name__, elapsed))
        return ret
    return wrapper
class MilvusClient(object):
def __init__(self, table_name=None, ip=None, port=None):
self._milvus = Milvus()
self._table_name = table_name
try:
if not ip:
self._milvus.connect(
host = SERVER_HOST_DEFAULT,
port = SERVER_PORT_DEFAULT)
else:
self._milvus.connect(
host = ip,
port = port)
except Exception as e:
raise e
    def __str__(self):
        # Identifier used in log output.
        return 'Milvus table %s' % self._table_name
    def check_status(self, status):
        # Raise if a Milvus status is not OK, logging its message first.
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")
def create_table(self, table_name, dimension, index_file_size, metric_type):
if not self._table_name:
self._table_name = table_name
if metric_type == "l2":
metric_type = MetricType.L2
elif metric_type == "ip":
metric_type = MetricType.IP
else:
logger.error("Not supported metric_type: %s" % metric_type)
create_param = {'table_name': table_name,
'dimension': dimension,
'index_file_size': index_file_size,
"metric_type": metric_type}
status = self._milvus.create_table(create_param)
self.check_status(status)
@time_wrapper
def insert(self, X, ids=None):
status, result = self._milvus.add_vectors(self._table_name, X, ids)
self.check_status(status)
return status, result
@time_wrapper
def create_index(self, index_type, nlist):
if index_type == "flat":
index_type = IndexType.FLAT
elif index_type == "ivf_flat":
index_type = IndexType.IVFLAT
elif index_type == "ivf_sq8":
index_type = IndexType.IVF_SQ8
elif index_type == "mix_nsg":
index_type = IndexType.MIX_NSG
elif index_type == "ivf_sq8h":
index_type = IndexType.IVF_SQ8H
index_params = {
"index_type": index_type,
"nlist": nlist,
}
logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
self.check_status(status)
def describe_index(self):
return self._milvus.describe_index(self._table_name)
def drop_index(self):
logger.info("Drop index: %s" % self._table_name)
return self._milvus.drop_index(self._table_name)
@time_wrapper
def query(self, X, top_k, nprobe):
status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X)
self.check_status(status)
return status, result
def count(self):
return self._milvus.get_table_row_count(self._table_name)[1]
def delete(self, timeout=60):
logger.info("Start delete table: %s" % self._table_name)
self._milvus.delete_table(self._table_name)
i = 0
while i < timeout:
if self.count():
time.sleep(1)
i = i + 1
continue
else:
break
if i < timeout:
logger.error("Delete table timeout")
def describe(self):
return self._milvus.describe_table(self._table_name)
def exists_table(self):
return self._milvus.has_table(self._table_name)
@time_wrapper
def preload_table(self):
return self._milvus.preload_table(self._table_name, timeout=3000)
def fit(table_name, X):
    """Insert vectors X into `table_name` on the default local server and
    log the insert status plus elapsed time."""
    milvus = Milvus()
    milvus.connect(host=SERVER_HOST_DEFAULT, port=SERVER_PORT_DEFAULT)
    start = time.time()
    status, ids = milvus.add_vectors(table_name, X)
    end = time.time()
    # BUG FIX: the original called the Logger object itself --
    # `logger(status, ...)` -- which raises TypeError. Log via logger.info.
    logger.info("%s insert run in %ss" % (status, round(end - start, 2)))
def fit_concurrent(table_name, process_num, vectors):
    """Insert `vectors` from `process_num` worker processes in parallel,
    each worker running fit() against the same table."""
    workers = [Process(target=fit, args=(table_name, vectors, ))
               for _ in range(process_num)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
# Ad-hoc manual smoke test: requires a reachable Milvus server; inserts
# random vectors, builds an IVF_SQ8H index and runs one search.
if __name__ == "__main__":
    # table_name = "sift_2m_20_128_l2"
    table_name = "test_tset1"
    m = MilvusClient(table_name)
    # m.create_table(table_name, 128, 50, "l2")
    print(m.describe())
    # print(m.count())
    # print(m.describe_index())
    # 5 batches of 10k random 128-d vectors.
    insert_vectors = [[random.random() for _ in range(128)] for _ in range(10000)]
    for i in range(5):
        m.insert(insert_vectors)
    print(m.create_index("ivf_sq8h", 16384))
    X = [insert_vectors[0]]
    top_k = 10
    nprobe = 10
    print(m.query(X, top_k, nprobe))

    # # # print(m.drop_index())
    # # print(m.describe_index())
    # # sys.exit()
    # # # insert_vectors = [[random.random() for _ in range(128)] for _ in range(100000)]
    # # # for i in range(100):
    # # #     m.insert(insert_vectors)
    # # # time.sleep(5)
    # # # print(m.describe_index())
    # # # print(m.drop_index())
    # # m.create_index("ivf_sq8h", 16384)
    # print(m.count())
    # print(m.describe_index())
    # sys.exit()
    # print(m.create_index("ivf_sq8h", 16384))
    # print(m.count())
    # print(m.describe_index())

    import numpy as np

    def mmap_fvecs(fname):
        # Memory-map an .fvecs file: first int32 of each record is the
        # dimension, the rest is the float32 vector.
        x = np.memmap(fname, dtype='int32', mode='r')
        d = x[0]
        return x.view('float32').reshape(-1, d + 1)[:, 1:]

    print(mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs"))
    # SIFT_SRC_QUERY_DATA_DIR = '/poc/yuncong/ann_1000m'
    # file_name = SIFT_SRC_QUERY_DATA_DIR+'/'+'query.npy'
    # data = numpy.load(file_name)
    # query_vectors = data[0:2].tolist()
    # print(len(query_vectors))
    # results = m.query(query_vectors, 10, 10)
    # result_ids = []
    # for result in results[1]:
    #     tmp = []
    #     for item in result:
    #         tmp.append(item.id)
    #     result_ids.append(tmp)
    # print(result_ids[0][:10])
    # # gt
    # file_name = SIFT_SRC_QUERY_DATA_DIR+"/gnd/"+"idx_1M.ivecs"
    # a = numpy.fromfile(file_name, dtype='int32')
    # d = a[0]
    # true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
    # print(true_ids[:3, :2])
    # print(len(true_ids[0]))
    # import numpy as np
    # import sklearn.preprocessing
    # def mmap_fvecs(fname):
    #     x = np.memmap(fname, dtype='int32', mode='r')
    #     d = x[0]
    #     return x.view('float32').reshape(-1, d + 1)[:, 1:]
    # data = mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs")
    # print(data[0], len(data[0]), len(data))
    # total_size = 10000
    # # total_size = 1000000000
    # file_size = 1000
    # # file_size = 100000
    # file_num = total_size // file_size
    # for i in range(file_num):
    #     fname = "/test/milvus/raw_data/deep1b/binary_96_%05d" % i
    #     print(fname, i*file_size, (i+1)*file_size)
    #     single_data = data[i*file_size : (i+1)*file_size]
    #     single_data = sklearn.preprocessing.normalize(single_data, axis=1, norm='l2')
    #     np.save(fname, single_data)

View File

@ -0,0 +1,28 @@
* GLOBAL:
FORMAT = "%datetime | %level | %logger | %msg"
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log"
ENABLED = true
TO_FILE = true
TO_STANDARD_OUTPUT = false
SUBSECOND_PRECISION = 3
PERFORMANCE_TRACKING = false
MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB
* DEBUG:
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log"
ENABLED = true
* WARNING:
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log"
* TRACE:
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log"
* VERBOSE:
FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
TO_FILE = false
TO_STANDARD_OUTPUT = false
## Error logs
* ERROR:
ENABLED = true
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log"
* FATAL:
ENABLED = true
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log"

View File

@ -0,0 +1,28 @@
cache_config:
cache_insert_data: false
cpu_cache_capacity: 16
gpu_cache_capacity: 6
cpu_cache_threshold: 0.85
db_config:
backend_url: sqlite://:@:/
build_index_gpu: 0
insert_buffer_size: 4
preload_table: null
primary_path: /opt/milvus
secondary_path: null
engine_config:
use_blas_threshold: 20
metric_config:
collector: prometheus
enable_monitor: true
prometheus_config:
port: 8080
resource_config:
resource_pool:
- cpu
- gpu0
server_config:
address: 0.0.0.0
deploy_mode: single
port: 19530
time_zone: UTC+8

View File

@ -0,0 +1,31 @@
server_config:
address: 0.0.0.0
port: 19530
deploy_mode: single
time_zone: UTC+8
db_config:
primary_path: /opt/milvus
secondary_path:
backend_url: sqlite://:@:/
insert_buffer_size: 4
build_index_gpu: 0
preload_table:
metric_config:
enable_monitor: false
collector: prometheus
prometheus_config:
port: 8080
cache_config:
cpu_cache_capacity: 16
cpu_cache_threshold: 0.85
cache_insert_data: false
engine_config:
use_blas_threshold: 20
resource_config:
resource_pool:
- cpu

View File

@ -0,0 +1,33 @@
server_config:
address: 0.0.0.0
port: 19530
deploy_mode: single
time_zone: UTC+8
db_config:
primary_path: /opt/milvus
secondary_path:
backend_url: sqlite://:@:/
insert_buffer_size: 4
build_index_gpu: 0
preload_table:
metric_config:
enable_monitor: false
collector: prometheus
prometheus_config:
port: 8080
cache_config:
cpu_cache_capacity: 16
cpu_cache_threshold: 0.85
cache_insert_data: false
engine_config:
use_blas_threshold: 20
resource_config:
resource_pool:
- cpu
- gpu0
- gpu1

View File

@ -0,0 +1,32 @@
server_config:
address: 0.0.0.0
port: 19530
deploy_mode: single
time_zone: UTC+8
db_config:
primary_path: /opt/milvus
secondary_path:
backend_url: sqlite://:@:/
insert_buffer_size: 4
build_index_gpu: 0
preload_table:
metric_config:
enable_monitor: false
collector: prometheus
prometheus_config:
port: 8080
cache_config:
cpu_cache_capacity: 16
cpu_cache_threshold: 0.85
cache_insert_data: false
engine_config:
use_blas_threshold: 20
resource_config:
resource_pool:
- cpu
- gpu0

View File

@ -0,0 +1,51 @@
import os
import logging
import pdb
import time
import random
from multiprocessing import Process
import numpy as np
from client import MilvusClient
# Endless stability loop: repeatedly query with random parameters and
# periodically insert / rebuild the index against a remote server.
nq = 100000
dimension = 128
run_count = 1
table_name = "sift_10m_1024_128_ip"
# Shared pool of random vectors used both as queries and as insert payload.
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]


def do_query(milvus, table_name, top_ks, nqs, nprobe, run_count):
    """Run every (nq, top_k) combination `run_count` times; return a matrix
    of average query latencies (rows: nqs, columns: top_ks)."""
    bi_res = []
    for index, nq in enumerate(nqs):
        tmp_res = []
        for top_k in top_ks:
            avg_query_time = 0.0
            total_query_time = 0.0
            vectors = insert_vectors[0:nq]
            for i in range(run_count):
                start_time = time.time()
                status, query_res = milvus.query(vectors, top_k, nprobe)
                total_query_time = total_query_time + (time.time() - start_time)
                if status.code:
                    print(status.message)
            avg_query_time = round(total_query_time / run_count, 2)
            tmp_res.append(avg_query_time)
        bi_res.append(tmp_res)
    return bi_res


while 1:
    milvus_instance = MilvusClient(table_name, ip="192.168.1.197", port=19530)
    top_ks = random.sample([x for x in range(1, 100)], 4)
    nqs = random.sample([x for x in range(1, 1000)], 3)
    nprobe = random.choice([x for x in range(1, 500)])
    res = do_query(milvus_instance, table_name, top_ks, nqs, nprobe, run_count)
    status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
    if not status.OK():
        # BUG FIX: the script never defined `logger`, so the original
        # `logger.error(...)` raised NameError here; report via print like
        # every other error path in this script.
        print(status.message)
    # status = milvus_instance.drop_index()
    if not status.OK():
        print(status.message)
    index_type = "ivf_sq8"
    status = milvus_instance.create_index(index_type, 16384)
    if not status.OK():
        print(status.message)

View File

@ -0,0 +1,261 @@
import os
import logging
import pdb
import time
import random
from multiprocessing import Process
import numpy as np
from client import MilvusClient
import utils
import parser
from runner import Runner
logger = logging.getLogger("milvus_benchmark.docker")
class DockerRunner(Runner):
    """Runner that launches a fresh Milvus server container per parameter set.

    For each definition it optionally rewrites the server config
    (keys prefixed with "server."), starts a container from `image`,
    executes the workload against it and removes the container afterwards.
    """

    def __init__(self, image):
        super(DockerRunner, self).__init__()
        # Docker image tag used to start the Milvus server for each run.
        self.image = image

    def run(self, definition, run_type=None):
        # Dispatch on test category: performance / accuracy / stability.
        if run_type == "performance":
            for op_type, op_value in definition.items():
                # run docker mode
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                if op_type == "insert":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["table_name"]
                        volume_name = param["db_path_prefix"]
                        print(table_name)
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        for k, v in param.items():
                            if k.startswith("server."):
                                # Update server config
                                utils.modify_config(k, v, type="server", db_slave=None)
                        container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                        time.sleep(2)
                        milvus = MilvusClient(table_name)
                        # Check has table or not; recreate from scratch.
                        if milvus.exists_table():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_table(table_name, dimension, index_file_size, metric_type)
                        res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
                        logger.info(res)
                        # wait for file merge
                        time.sleep(6 * (table_size / 500000))
                        # Clear up
                        utils.remove_container(container)
                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["dataset"]
                        volume_name = param["db_path_prefix"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        for k, v in param.items():
                            if k.startswith("server."):
                                utils.modify_config(k, v, type="server")
                        container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                        time.sleep(2)
                        milvus = MilvusClient(table_name)
                        logger.debug(milvus._milvus.show_tables())
                        # Check has table or not
                        if not milvus.exists_table():
                            logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                            continue
                        # parse index info
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # parse top-k, nq, nprobe
                        top_ks, nqs, nprobes = parser.search_params_parser(param)
                        for index_type in index_types:
                            for nlist in nlists:
                                result = milvus.describe_index()
                                logger.info(result)
                                milvus.create_index(index_type, nlist)
                                result = milvus.describe_index()
                                logger.info(result)
                                # preload index
                                milvus.preload_table()
                                logger.info("Start warm up query")
                                res = self.do_query(milvus, table_name, [1], [1], 1, 1)
                                logger.info("End warm up query")
                                # Run query test
                                for nprobe in nprobes:
                                    logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                    res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                                    headers = ["Nprobe/Top-k"]
                                    headers.extend([str(top_k) for top_k in top_ks])
                                    utils.print_table(headers, nqs, res)
                        utils.remove_container(container)
        elif run_type == "accuracy":
            # Example of a `query` param block for an accuracy run:
            """
            {
                "dataset": "random_50m_1024_512",
                "index.index_types": ["flat", ivf_flat", "ivf_sq8"],
                "index.nlists": [16384],
                "nprobes": [1, 32, 128],
                "nqs": [100],
                "top_ks": [1, 64],
                "server.use_blas_threshold": 1100,
                "server.cpu_cache_capacity": 256
            }
            """
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    # sift_acc: compare against sift ground-truth files rather
                    # than against the "flat" baseline run.
                    sift_acc = False
                    if "sift_acc" in param:
                        sift_acc = param["sift_acc"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")
                    volume_name = param["db_path_prefix"]
                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(table_name)
                    # Check has table or not
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue
                    # parse index info
                    index_types = param["index.index_types"]
                    nlists = param["index.nlists"]
                    # parse top-k, nq, nprobe
                    top_ks, nqs, nprobes = parser.search_params_parser(param)
                    if sift_acc is True:
                        # preload groundtruth data
                        true_ids_all = self.get_groundtruth_ids(table_size)
                    acc_dict = {}
                    for index_type in index_types:
                        for nlist in nlists:
                            result = milvus.describe_index()
                            logger.info(result)
                            milvus.create_index(index_type, nlist)
                            # preload index
                            milvus.preload_table()
                            # Run query test
                            for nprobe in nprobes:
                                logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                for top_k in top_ks:
                                    for nq in nqs:
                                        result_ids = []
                                        id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                            (table_name, index_type, nlist, metric_type, nprobe, top_k, nq)
                                        if sift_acc is False:
                                            self.do_query_acc(milvus, table_name, top_k, nq, nprobe, id_prefix)
                                            if index_type != "flat":
                                                # Compute accuracy against the stored "flat" baseline ids.
                                                base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                                    (table_name, nlist, metric_type, nprobe, top_k, nq)
                                                avg_acc = self.compute_accuracy(base_name, id_prefix)
                                                logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc))
                                        else:
                                            result_ids = self.do_query_ids(milvus, table_name, top_k, nq, nprobe)
                                            acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
                                            logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value))
                    # # print accuracy table
                    # headers = [table_name]
                    # headers.extend([str(top_k) for top_k in top_ks])
                    # utils.print_table(headers, nqs, res)
                    # remove container, and run next definition
                    logger.info("remove container, and run next definition")
                    utils.remove_container(container)
        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    volume_name = param["db_path_prefix"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                    # set default test time
                    if "during_time" not in param:
                        during_time = 100 # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # set default query process num
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])
                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")
                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(table_name)
                    # Check has table or not
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue
                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]
                    # Keep querying (and periodically inserting/re-indexing)
                    # until the time budget is exhausted.
                    while time.time() < start_time + during_time:
                        processes = []
                        # do query
                        # for i in range(query_process_num):
                        #     milvus_instance = MilvusClient(table_name)
                        #     top_k = random.choice([x for x in range(1, 100)])
                        #     nq = random.choice([x for x in range(1, 100)])
                        #     nprobe = random.choice([x for x in range(1, 1000)])
                        #     # logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                        #     p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], [nprobe], run_count, ))
                        #     processes.append(p)
                        #     p.start()
                        #     time.sleep(0.1)
                        # for p in processes:
                        #     p.join()
                        milvus_instance = MilvusClient(table_name)
                        top_ks = random.sample([x for x in range(1, 100)], 3)
                        nqs = random.sample([x for x in range(1, 1000)], 3)
                        nprobe = random.choice([x for x in range(1, 500)])
                        res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                        if int(time.time() - start_time) % 120 == 0:
                            status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                            if not status.OK():
                                logger.error(status)
                            # status = milvus_instance.drop_index()
                            # if not status.OK():
                            #     logger.error(status)
                            # index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
                            result = milvus.describe_index()
                            logger.info(result)
                            milvus_instance.create_index("ivf_sq8", 16384)
                    utils.remove_container(container)
        else:
            logger.warning("Run type: %s not supported" % run_type)

View File

@ -0,0 +1,132 @@
import os
import logging
import pdb
import time
import random
from multiprocessing import Process
import numpy as np
from client import MilvusClient
import utils
import parser
from runner import Runner
logger = logging.getLogger("milvus_benchmark.local_runner")
class LocalRunner(Runner):
    """Runner that targets an already-running Milvus server at ip:port
    (no containers are created or destroyed)."""

    def __init__(self, ip, port):
        super(LocalRunner, self).__init__()
        # Endpoint of the existing server to benchmark against.
        self.ip = ip
        self.port = port

    def run(self, definition, run_type=None):
        # Dispatch on test category: performance / stability.
        if run_type == "performance":
            for op_type, op_value in definition.items():
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                if op_type == "insert":
                    for index, param in enumerate(run_params):
                        table_name = param["table_name"]
                        # random_1m_100_512
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # Check has table or not; recreate from scratch.
                        if milvus.exists_table():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_table(table_name, dimension, index_file_size, metric_type)
                        res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
                        logger.info(res)
                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["dataset"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # parse index info
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # parse top-k, nq, nprobe
                        top_ks, nqs, nprobes = parser.search_params_parser(param)
                        for index_type in index_types:
                            for nlist in nlists:
                                milvus.create_index(index_type, nlist)
                                # preload index
                                milvus.preload_table()
                                # Run query test
                                for nprobe in nprobes:
                                    logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                    res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                                    headers = [param["dataset"]]
                                    headers.extend([str(top_k) for top_k in top_ks])
                                    utils.print_table(headers, nqs, res)
        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                nq = 10000
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                    # set default test time
                    if "during_time" not in param:
                        during_time = 100 # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # set default query process num
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])
                    milvus = MilvusClient(table_name)
                    # Check has table or not
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue
                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
                    # Keep querying / inserting / re-indexing until the time
                    # budget is exhausted.
                    while time.time() < start_time + during_time:
                        processes = []
                        # # do query
                        # for i in range(query_process_num):
                        #     milvus_instance = MilvusClient(table_name)
                        #     top_k = random.choice([x for x in range(1, 100)])
                        #     nq = random.choice([x for x in range(1, 1000)])
                        #     nprobe = random.choice([x for x in range(1, 500)])
                        #     logger.info(nprobe)
                        #     p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], 64, run_count, ))
                        #     processes.append(p)
                        #     p.start()
                        #     time.sleep(0.1)
                        # for p in processes:
                        #     p.join()
                        milvus_instance = MilvusClient(table_name)
                        top_ks = random.sample([x for x in range(1, 100)], 4)
                        nqs = random.sample([x for x in range(1, 1000)], 3)
                        nprobe = random.choice([x for x in range(1, 500)])
                        res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                        # milvus_instance = MilvusClient(table_name)
                        status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                        if not status.OK():
                            logger.error(status.message)
                        # Roughly every 5 minutes rebuild a random index type.
                        if (time.time() - start_time) % 300 == 0:
                            status = milvus_instance.drop_index()
                            if not status.OK():
                                logger.error(status.message)
                            index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
                            status = milvus_instance.create_index(index_type, 16384)
                            if not status.OK():
                                logger.error(status.message)

View File

@ -0,0 +1,131 @@
import os
import sys
import time
import pdb
import argparse
import logging
import utils
from yaml import load, dump
from logging import handlers
from parser import operations_parser
from local_runner import LocalRunner
from docker_runner import DockerRunner
DEFAULT_IMAGE = "milvusdb/milvus:latest"
LOG_FOLDER = "benchmark_logs"

logger = logging.getLogger("milvus_benchmark")
formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
# Ensure the log folder exists before attaching the file handler.
if not os.path.exists(LOG_FOLDER):
    os.system('mkdir -p %s' % LOG_FOLDER)
# Rotate the benchmark log daily ("D", interval 1), keeping 10 backups.
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'milvus_benchmark'), "D", 1, 10)
fileTimeHandler.suffix = "%Y%m%d.log"
fileTimeHandler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG)
# NOTE(review): setFormatter is called twice on the same handler; the
# second call is redundant but harmless.
fileTimeHandler.setFormatter(formatter)
logger.addHandler(fileTimeHandler)
def positive_int(s):
    """argparse `type` callback: accept only strictly positive integers."""
    try:
        value = int(s)
    except ValueError:
        value = None
    if value is None or value < 1:
        raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
    return value
# # link random_data if not exists
# def init_env():
# if not os.path.islink(BINARY_DATA_FOLDER):
# try:
# os.symlink(SRC_BINARY_DATA_FOLDER, BINARY_DATA_FOLDER)
# except Exception as e:
# logger.error("Create link failed: %s" % str(e))
# sys.exit()
def main():
    """CLI entry point: parse arguments, load the suites file and dispatch
    to docker mode (--image) or local mode (--local)."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--image',
        help='use the given image')
    parser.add_argument(
        '--local',
        action='store_true',
        help='use local milvus server')
    parser.add_argument(
        "--run-count",
        default=1,
        type=positive_int,
        help="run each db operation times")
    # performance / stability / accuracy test
    parser.add_argument(
        "--run-type",
        default="performance",
        help="run type, default performance")
    parser.add_argument(
        '--suites',
        metavar='FILE',
        help='load test suites from FILE',
        default='suites.yaml')
    parser.add_argument(
        '--ip',
        help='server ip param for local mode',
        default='127.0.0.1')
    parser.add_argument(
        '--port',
        help='server port param for local mode',
        default='19530')
    args = parser.parse_args()

    operations = None
    # Get all benchmark test suites
    if args.suites:
        # NOTE(review): yaml load() without an explicit Loader is unsafe for
        # untrusted input; acceptable here for local suite files only.
        # The redundant f.close() inside `with` was removed.
        with open(args.suites) as f:
            suites_dict = load(f)
        # With definition order
        operations = operations_parser(suites_dict, run_type=args.run_type)

    # init_env()
    run_params = {"run_count": args.run_count}

    if args.image:
        # for docker mode
        if args.local:
            logger.error("Local mode and docker mode are incompatible arguments")
            sys.exit(-1)
        # Docker pull image
        if not utils.pull_image(args.image):
            # BUG FIX: the original raised with the undefined name `image`
            # (NameError); it must reference args.image.
            raise Exception('Image %s pull failed' % args.image)
        # TODO: Check milvus server port is available
        logger.info("Init: remove all containers created with image: %s" % args.image)
        utils.remove_all_containers(args.image)
        runner = DockerRunner(args.image)
        for operation_type in operations:
            logger.info("Start run test, test type: %s" % operation_type)
            run_params["params"] = operations[operation_type]
            runner.run({operation_type: run_params}, run_type=args.run_type)
            logger.info("Run params: %s" % str(run_params))

    if args.local:
        # for local mode
        ip = args.ip
        port = args.port
        runner = LocalRunner(ip, port)
        for operation_type in operations:
            logger.info("Start run local mode test, test type: %s" % operation_type)
            run_params["params"] = operations[operation_type]
            runner.run({operation_type: run_params}, run_type=args.run_type)
            logger.info("Run params: %s" % str(run_params))
# Script entry point.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,10 @@
from __future__ import absolute_import
import pdb
import time
class Base(object):
    """Placeholder base class for benchmark operation types (not yet implemented)."""
    pass


class Insert(Base):
    """Placeholder for the insert operation type."""
    pass

View File

@ -0,0 +1,66 @@
import pdb
import logging
logger = logging.getLogger("milvus_benchmark.parser")
def operations_parser(operations, run_type="performance"):
    """Return the suite definitions registered under `run_type`.

    Raises KeyError if the suites file has no section for that run type.
    """
    return operations[run_type]
def table_parser(table_name):
    """Split a benchmark table name such as ``sift_10m_1024_128_l2`` into
    ``(data_type, table_size, index_file_size, dimension, metric_type)``.

    The size component accepts an ``m`` (million) or ``b`` (billion)
    suffix; a plain integer row count is also accepted and used as-is.
    """
    tmp = table_name.split("_")
    # if len(tmp) != 5:
    #     return None
    data_type = tmp[0]
    size_token = tmp[1]
    if size_token[-1] == "m":
        table_size = int(size_token[0:-1]) * 1000000
    elif size_token[-1] == "b":
        table_size = int(size_token[0:-1]) * 1000000000
    else:
        # Generalization: a bare row count. The original stripped the last
        # character unconditionally, silently dropping a digit and leaving
        # the size as a string for unsuffixed names.
        table_size = int(size_token)
    index_file_size = int(tmp[2])
    dimension = int(tmp[3])
    metric_type = str(tmp[4])
    return (data_type, table_size, index_file_size, dimension, metric_type)
def _parse_int_or_list(param, key, label, default):
    """Normalize a search parameter to a list of ints.

    Missing key -> `default`; scalar int -> single-element list; list ->
    shallow copy. Any other type is logged as a warning and returned
    unchanged (matching the original fall-through behavior).
    """
    if key not in param:
        return default
    value = param[key]
    if isinstance(value, int):
        return [value]
    if isinstance(value, list):
        return list(value)
    logger.warning("Invalid format %s: %s" % (label, str(value)))
    return value


def search_params_parser(param):
    """Extract (top_ks, nqs, nprobes) from a suite param dict, each
    normalized to a list of ints with sensible defaults.

    The original repeated the same parse stanza three times; it is now
    factored into _parse_int_or_list.
    """
    # parse top-k, set default value if top-k not in param
    top_ks = _parse_int_or_list(param, "top_ks", "top-ks", [10])
    # parse nqs, set default value if nq not in param
    nqs = _parse_int_or_list(param, "nqs", "nqs", [10])
    # parse nprobes
    nprobes = _parse_int_or_list(param, "nprobes", "nprobes", [1])
    return top_ks, nqs, nprobes

View File

@ -0,0 +1,10 @@
# from tablereport import Table
# from tablereport.shortcut import write_to_excel
# RESULT_FOLDER = "results"
# def create_table(headers, bodys, table_name):
# table = Table(header=[headers],
# body=[bodys])
# write_to_excel('%s/%s.xlsx' % (RESULT_FOLDER, table_name), table)

View File

@ -0,0 +1,6 @@
numpy==1.16.3
pymilvus>=0.1.18
pyyaml==3.12
docker==4.0.2
tableprint==0.8.0
ansicolors==1.1.8

View File

@ -0,0 +1,219 @@
import os
import logging
import pdb
import time
import random
from multiprocessing import Process
import numpy as np
from client import MilvusClient
import utils
import parser
logger = logging.getLogger("milvus_benchmark.runner")
# Default Milvus endpoint used when no host is supplied.
SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530

# Chunking of the pre-generated binary datasets.
VECTORS_PER_FILE = 1000000
SIFT_VECTORS_PER_FILE = 100000
MAX_NQ = 10001
FILE_PREFIX = "binary_"

# Locations of the pre-generated datasets on the benchmark host.
RANDOM_SRC_BINARY_DATA_DIR = '/tmp/random/binary_data'
SIFT_SRC_DATA_DIR = '/tmp/sift1b/query'
SIFT_SRC_BINARY_DATA_DIR = '/tmp/sift1b/binary_data'
SIFT_SRC_GROUNDTRUTH_DATA_DIR = '/tmp/sift1b/groundtruth'

# Warm-up query parameters.
WARM_TOP_K = 1
WARM_NQ = 1
DEFAULT_DIM = 512

# Maps a table row count to its sift1b ground-truth ids file.
GROUNDTRUTH_MAP = {
    "1000000": "idx_1M.ivecs",
    "2000000": "idx_2M.ivecs",
    "5000000": "idx_5M.ivecs",
    "10000000": "idx_10M.ivecs",
    "20000000": "idx_20M.ivecs",
    "50000000": "idx_50M.ivecs",
    "100000000": "idx_100M.ivecs",
    "200000000": "idx_200M.ivecs",
    "500000000": "idx_500M.ivecs",
    "1000000000": "idx_1000M.ivecs",
}


def gen_file_name(idx, table_dimension, data_type):
    """Return the path of the idx-th .npy chunk for the given dataset type."""
    base_name = "{}{}d_{}.npy".format(FILE_PREFIX, table_dimension, "%05d" % idx)
    if data_type == "random":
        return RANDOM_SRC_BINARY_DATA_DIR + '/' + base_name
    if data_type == "sift":
        return SIFT_SRC_BINARY_DATA_DIR + '/' + base_name
    # Unknown dataset type: fall back to the bare file name.
    return base_name
def get_vectors_from_binary(nq, dimension, data_type):
    """Load the first `nq` query vectors of the given dataset as a list.

    Raises for an oversized nq or for an unknown `data_type` (the original
    fell through to a NameError on `file_name` in that case).
    """
    # use the first file, nq should be less than VECTORS_PER_FILE
    if nq > MAX_NQ:
        raise Exception("Over size nq")
    if data_type == "random":
        file_name = gen_file_name(0, dimension, data_type)
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR+'/'+'query.npy'
    else:
        # BUG FIX: explicit error instead of NameError on file_name below.
        raise Exception("Unsupported data type: %s" % data_type)
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors
class Runner(object):
    def __init__(self):
        # Stateless base class; subclasses hold connection/runtime state.
        pass
def do_insert(self, milvus, table_name, data_type, dimension, size, ni):
'''
@params:
mivlus: server connect instance
dimension: table dimensionn
# index_file_size: size trigger file merge
size: row count of vectors to be insert
ni: row count of vectors to be insert each time
# store_id: if store the ids returned by call add_vectors or not
@return:
total_time: total time for all insert operation
qps: vectors added per second
ni_time: avarage insert operation time
'''
bi_res = {}
total_time = 0.0
qps = 0.0
ni_time = 0.0
if data_type == "random":
vectors_per_file = VECTORS_PER_FILE
elif data_type == "sift":
vectors_per_file = SIFT_VECTORS_PER_FILE
if size % vectors_per_file or ni > vectors_per_file:
raise Exception("Not invalid table size or ni")
file_num = size // vectors_per_file
for i in range(file_num):
file_name = gen_file_name(i, dimension, data_type)
logger.info("Load npy file: %s start" % file_name)
data = np.load(file_name)
logger.info("Load npy file: %s end" % file_name)
loops = vectors_per_file // ni
for j in range(loops):
vectors = data[j*ni:(j+1)*ni].tolist()
ni_start_time = time.time()
# start insert vectors
start_id = i * vectors_per_file + j * ni
end_id = start_id + len(vectors)
logger.info("Start id: %s, end id: %s" % (start_id, end_id))
ids = [k for k in range(start_id, end_id)]
status, ids = milvus.insert(vectors, ids=ids)
ni_end_time = time.time()
total_time = total_time + ni_end_time - ni_start_time
qps = round(size / total_time, 2)
ni_time = round(total_time / (loops * file_num), 2)
bi_res["total_time"] = round(total_time, 2)
bi_res["qps"] = qps
bi_res["ni_time"] = ni_time
return bi_res
def do_query(self, milvus, table_name, top_ks, nqs, nprobe, run_count):
(data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
bi_res = []
for index, nq in enumerate(nqs):
tmp_res = []
for top_k in top_ks:
avg_query_time = 0.0
total_query_time = 0.0
vectors = base_query_vectors[0:nq]
logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
for i in range(run_count):
logger.info("Start run query, run %d of %s" % (i+1, run_count))
start_time = time.time()
status, query_res = milvus.query(vectors, top_k, nprobe)
total_query_time = total_query_time + (time.time() - start_time)
if status.code:
logger.error("Query failed with message: %s" % status.message)
avg_query_time = round(total_query_time / run_count, 2)
logger.info("Avarage query time: %.2f" % avg_query_time)
tmp_res.append(avg_query_time)
bi_res.append(tmp_res)
return bi_res
def do_query_ids(self, milvus, table_name, top_k, nq, nprobe):
(data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
vectors = base_query_vectors[0:nq]
logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
status, query_res = milvus.query(vectors, top_k, nprobe)
if not status.OK():
msg = "Query failed with message: %s" % status.message
raise Exception(msg)
result_ids = []
for result in query_res:
tmp = []
for item in result:
tmp.append(item.id)
result_ids.append(tmp)
return result_ids
def do_query_acc(self, milvus, table_name, top_k, nq, nprobe, id_store_name):
(data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
vectors = base_query_vectors[0:nq]
logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
status, query_res = milvus.query(vectors, top_k, nprobe)
if not status.OK():
msg = "Query failed with message: %s" % status.message
raise Exception(msg)
# if file existed, cover it
if os.path.isfile(id_store_name):
os.remove(id_store_name)
with open(id_store_name, 'a+') as fd:
for nq_item in query_res:
for item in nq_item:
fd.write(str(item.id)+'\t')
fd.write('\n')
# compute and print accuracy
def compute_accuracy(self, flat_file_name, index_file_name):
flat_id_list = []; index_id_list = []
logger.info("Loading flat id file: %s" % flat_file_name)
with open(flat_file_name, 'r') as flat_id_fd:
for line in flat_id_fd:
tmp_list = line.strip("\n").strip().split("\t")
flat_id_list.append(tmp_list)
logger.info("Loading index id file: %s" % index_file_name)
with open(index_file_name) as index_id_fd:
for line in index_id_fd:
tmp_list = line.strip("\n").strip().split("\t")
index_id_list.append(tmp_list)
if len(flat_id_list) != len(index_id_list):
raise Exception("Flat index result length: <flat: %s, index: %s> not match, Acc compute exiting ..." % (len(flat_id_list), len(index_id_list)))
# get the accuracy
return self.get_recall_value(flat_id_list, index_id_list)
def get_recall_value(self, flat_id_list, index_id_list):
"""
Use the intersection length
"""
sum_radio = 0.0
for index, item in enumerate(index_id_list):
tmp = set(item).intersection(set(flat_id_list[index]))
sum_radio = sum_radio + len(tmp) / len(item)
return round(sum_radio / len(index_id_list), 3)
"""
Implementation based on:
https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
"""
def get_groundtruth_ids(self, table_size):
fname = GROUNDTRUTH_MAP[str(table_size)]
fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
a = np.fromfile(fname, dtype='int32')
d = a[0]
true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
return true_ids

View File

@ -0,0 +1,38 @@
# data sets
datasets:
hf5:
gist-960,sift-128
npy:
50000000-512, 100000000-512
operations:
# interface: search_vectors
query:
# dataset: table name you have already created
# keys starting with "server." require reconfiguring and restarting the server, including nprobe/nlist/use_blas_threshold/..
[
# debug
# {"dataset": "ip_ivfsq8_1000", "top_ks": [16], "nqs": [1], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
{"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
{"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
{"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
{"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
{"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
# {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], "nqs": [1, 10, 100, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
]
# interface: add_vectors
insert:
# index_type: flat/ivf_flat/ivf_sq8
[
# debug
{"table_name": "ip_ivf_flat_20m_1024", "table.index_type": "ivf_flat", "server.index_building_threshold": 1024, "table.size": 20000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110},
{"table_name": "ip_ivf_sq8_50m_1024", "table.index_type": "ivf_sq8", "server.index_building_threshold": 1024, "table.size": 50000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110},
]
# TODO: interface: build_index
build: []

View File

@ -0,0 +1,121 @@
accuracy:
# interface: search_vectors
query:
[
{
"dataset": "random_20m_1024_512_ip",
# index info
"index.index_types": ["flat", "ivf_sq8"],
"index.nlists": [16384],
"index.metric_types": ["ip"],
"nprobes": [1, 16, 64],
"top_ks": [64],
"nqs": [100],
"server.cpu_cache_capacity": 100,
"server.resources": ["cpu", "gpu0"],
"db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip",
},
# {
# "dataset": "sift_50m_1024_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "index.metric_types": ["l2"],
# "nprobes": [1, 16, 64],
# "top_ks": [64],
# "nqs": [100],
# "server.cpu_cache_capacity": 160,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2",
# "sift_acc": true
# },
# {
# "dataset": "sift_50m_1024_128_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "index.metric_types": ["l2"],
# "nprobes": [1, 16, 64],
# "top_ks": [64],
# "nqs": [100],
# "server.cpu_cache_capacity": 160,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2_sq8",
# "sift_acc": true
# },
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "index.metric_types": ["l2"],
# "nprobes": [1, 16, 64, 128],
# "top_ks": [64],
# "nqs": [100],
# "server.cpu_cache_capacity": 200,
# "server.resources": ["cpu"],
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
# "sift_acc": true
# },
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "index.metric_types": ["l2"],
# "nprobes": [1, 16, 64, 128],
# "top_ks": [64],
# "nqs": [100],
# "server.cpu_cache_capacity": 200,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
# "sift_acc": true
# },
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "index.metric_types": ["l2"],
# "nprobes": [1, 16, 64, 128],
# "top_ks": [64],
# "nqs": [100],
# "server.cpu_cache_capacity": 200,
# "server.resources": ["cpu", "gpu0", "gpu1"],
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
# "sift_acc": true
# },
# {
# "dataset": "sift_1m_1024_128_l2",
# "index.index_types": ["flat", "ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1, 32, 128, 256, 512],
# "nqs": 10,
# "top_ks": 10,
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 16,
# },
# {
# "dataset": "sift_10m_1024_128_l2",
# "index.index_types": ["flat", "ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1, 32, 128, 256, 512],
# "nqs": 10,
# "top_ks": 10,
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 32,
# },
# {
# "dataset": "sift_50m_1024_128_l2",
# "index.index_types": ["flat", "ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1, 32, 128, 256, 512],
# "nqs": 10,
# "top_ks": 10,
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 64,
# }
]

View File

@ -0,0 +1,258 @@
performance:
# interface: add_vectors
insert:
# index_type: flat/ivf_flat/ivf_sq8/mix_nsg
[
# debug
# data_type / data_size / index_file_size / dimension
# data_type: random / ann_sift
# data_size: 10m / 1b
# {
# "table_name": "random_50m_1024_512_ip",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# "server.cpu_cache_capacity": 16,
# # "server.resources": ["gpu0", "gpu1"],
# "db_path_prefix": "/test/milvus/db_data"
# },
# {
# "table_name": "random_5m_1024_512_ip",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# "server.cpu_cache_capacity": 16,
# "server.resources": ["gpu0", "gpu1"],
# "db_path_prefix": "/test/milvus/db_data/random_5m_1024_512_ip"
# },
# {
# "table_name": "sift_1m_50_128_l2",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# # "server.cpu_cache_capacity": 16,
# "db_path_prefix": "/test/milvus/db_data"
# },
# {
# "table_name": "sift_1m_256_128_l2",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# # "server.cpu_cache_capacity": 16,
# "db_path_prefix": "/test/milvus/db_data"
# }
# {
# "table_name": "sift_50m_1024_128_l2",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# # "server.cpu_cache_capacity": 16,
# },
# {
# "table_name": "sift_100m_1024_128_l2",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# },
# {
# "table_name": "sift_1b_2048_128_l2",
# "ni_per": 100000,
# "processes": 5, # multiprocessing
# "server.cpu_cache_capacity": 16,
# }
]
# interface: search_vectors
query:
# dataset: table name you have already created
# key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity ..
[
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "nprobes": [8, 32],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
# "nqs": [1, 10, 100, 500, 1000],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 200,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h"
# },
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [8, 32],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
# "nqs": [1, 10, 100, 500, 1000],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 200,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2"
# },
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "nprobes": [8, 32],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
# "nqs": [1, 10, 100, 500, 1000],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 200,
# "server.resources": ["cpu"],
# "db_path_prefix": "/test/milvus/db_data"
# },
{
"dataset": "random_50m_1024_512_ip",
"index.index_types": ["ivf_sq8h"],
"index.nlists": [16384],
"nprobes": [8],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
"top_ks": [512],
# "nqs": [1, 10, 100, 500, 1000],
"nqs": [500],
"server.use_blas_threshold": 1100,
"server.cpu_cache_capacity": 150,
"server.gpu_cache_capacity": 6,
"server.resources": ["cpu", "gpu0", "gpu1"],
"db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip"
},
# {
# "dataset": "random_50m_1024_512_ip",
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [8, 32],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
# "nqs": [1, 10, 100, 500, 1000],
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 150,
# "server.resources": ["cpu", "gpu0", "gpu1"],
# "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip_sq8"
# },
# {
# "dataset": "random_20m_1024_512_ip",
# "index.index_types": ["flat"],
# "index.nlists": [16384],
# "nprobes": [50],
# "top_ks": [64],
# "nqs": [10],
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 100,
# "server.resources": ["cpu", "gpu0", "gpu1"],
# "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip"
# },
# {
# "dataset": "random_100m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [8, 32],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
# "nqs": [1, 10, 100, 500, 1000],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 250,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data"
# },
# {
# "dataset": "random_100m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [8, 32],
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
# "nqs": [1, 10, 100, 500, 1000],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 250,
# "server.resources": ["cpu"],
# "db_path_prefix": "/test/milvus/db_data"
# },
# {
# "dataset": "random_10m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# # "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 16,
# },
# {
# "dataset": "random_10m_1024_512_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 64
# },
# {
# "dataset": "sift_500m_1024_128_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 8, 16, 64, 256, 512, 1000],
# "nqs": [1, 100, 500, 800, 1000, 1500],
# # "top_ks": [256],
# # "nqs": [800],
# "processes": 1, # multiprocessing
# # "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 120,
# "server.resources": ["gpu0", "gpu1"],
# "db_path_prefix": "/test/milvus/db_data"
# },
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8h"],
# "index.nlists": [16384],
# "nprobes": [1],
# # "top_ks": [1],
# # "nqs": [1],
# "top_ks": [256],
# "nqs": [800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 110,
# "server.resources": ["cpu", "gpu0"],
# "db_path_prefix": "/test/milvus/db_data"
# },
# {
# "dataset": "random_50m_1024_512_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# # "top_ks": [256],
# # "nqs": [800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 128
# },
# [
# {
# "dataset": "sift_1m_50_128_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1],
# "nqs": [1],
# "db_path_prefix": "/test/milvus/db_data"
# # "processes": 1, # multiprocessing
# # "server.use_blas_threshold": 1100,
# # "server.cpu_cache_capacity": 256
# }
]

View File

@ -0,0 +1,17 @@
stability:
# interface: search_vectors / add_vectors mix operation
query:
[
{
"dataset": "random_20m_1024_512_ip",
# "nqs": [1, 10, 100, 1000, 10000],
# "pds": [0.1, 0.44, 0.44, 0.02],
"query_process_num": 10,
# each 10s, do an insertion
# "insert_interval": 1,
# minutes
"during_time": 360,
"server.cpu_cache_capacity": 100
},
]

View File

@ -0,0 +1,171 @@
#"server.resources": ["gpu0", "gpu1"]
performance:
# interface: search_vectors
query:
# dataset: table name you have already created
# key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity ..
[
# debug
# {
# "dataset": "random_10m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 16,
# },
# {
# "dataset": "random_10m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 16,
# },
# {
# "dataset": "random_10m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# # "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 16,
# },
# {
# "dataset": "random_10m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# # "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 16,
# },
# {
# "dataset": "random_10m_1024_512_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 64
# },
# {
# "dataset": "sift_50m_1024_128_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1, 32, 128],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# # "top_ks": [256],
# # "nqs": [800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 310,
# "server.resources": ["gpu0", "gpu1"]
# },
{
"dataset": "sift_1m_1024_128_l2",
# index info
"index.index_types": ["ivf_sq8"],
"index.nlists": [16384],
"nprobes": [32],
"top_ks": [10],
"nqs": [100],
# "top_ks": [256],
# "nqs": [800],
"processes": 1, # multiprocessing
"server.use_blas_threshold": 1100,
"server.cpu_cache_capacity": 310,
"server.resources": ["cpu"]
},
{
"dataset": "sift_1m_1024_128_l2",
# index info
"index.index_types": ["ivf_sq8"],
"index.nlists": [16384],
"nprobes": [32],
"top_ks": [10],
"nqs": [100],
# "top_ks": [256],
# "nqs": [800],
"processes": 1, # multiprocessing
"server.use_blas_threshold": 1100,
"server.cpu_cache_capacity": 310,
"server.resources": ["gpu0"]
},
{
"dataset": "sift_1m_1024_128_l2",
# index info
"index.index_types": ["ivf_sq8"],
"index.nlists": [16384],
"nprobes": [32],
"top_ks": [10],
"nqs": [100],
# "top_ks": [256],
# "nqs": [800],
"processes": 1, # multiprocessing
"server.use_blas_threshold": 1100,
"server.cpu_cache_capacity": 310,
"server.resources": ["gpu0", "gpu1"]
},
# {
# "dataset": "sift_1b_2048_128_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# # "top_ks": [256],
# # "nqs": [800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 310
# },
# {
# "dataset": "random_50m_1024_512_l2",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# # "top_ks": [256],
# # "nqs": [800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 128,
# "server.resources": ["gpu0", "gpu1"]
# },
# {
# "dataset": "random_100m_1024_512_ip",
# # index info
# "index.index_types": ["ivf_sq8"],
# "index.nlists": [16384],
# "nprobes": [1],
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
# "nqs": [1, 10, 100, 500, 800],
# "processes": 1, # multiprocessing
# "server.use_blas_threshold": 1100,
# "server.cpu_cache_capacity": 256
# },
]

View File

@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
__true_print = print # noqa
import os
import sys
import pdb
import time
import datetime
import argparse
import threading
import logging
import docker
import multiprocessing
import numpy
# import psutil
from yaml import load, dump
import tableprint as tp
logger = logging.getLogger("milvus_benchmark.utils")
MULTI_DB_SLAVE_PATH = "/opt/milvus/data2;/opt/milvus/data3"
def get_current_time():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = time.localtime()
    return time.strftime('%Y-%m-%d %H:%M:%S', now)
def print_table(headers, columns, data):
    """Render a table via tableprint: row i is the i-th column label
    followed by the values of data[i]."""
    rows = [[label] + list(data[i]) for i, label in enumerate(columns)]
    tp.table(rows, headers)
def modify_config(k, v, type=None, file_path="conf/server_config.yaml", db_slave=None):
    """Update one server config key in a YAML file in place.

    @params:
        k: config key, matched by substring (use_blas_threshold /
           cpu_cache_capacity / gpu_cache_capacity / resource_pool)
        v: new value; numeric keys are cast to int
        type: unused, kept for backward compatibility
        file_path: path of the YAML config file to rewrite
        db_slave: when truthy, also set db_config.db_slave_path to
            MULTI_DB_SLAVE_PATH
    @raises Exception: if the file is missing or does not parse to a mapping
    """
    if not os.path.isfile(file_path):
        raise Exception('File: %s not found' % file_path)
    # `with` closes the file itself; the original also called f.close() redundantly
    with open(file_path) as f:
        config_dict = load(f)
    if not config_dict:
        raise Exception('Load file:%s error' % file_path)
    if "use_blas_threshold" in k:
        config_dict['engine_config']['use_blas_threshold'] = int(v)
    elif "cpu_cache_capacity" in k:
        config_dict['cache_config']['cpu_cache_capacity'] = int(v)
    elif "gpu_cache_capacity" in k:
        config_dict['cache_config']['gpu_cache_capacity'] = int(v)
    elif "resource_pool" in k:
        config_dict['resource_config']['resource_pool'] = v
    if db_slave:
        config_dict['db_config']['db_slave_path'] = MULTI_DB_SLAVE_PATH
    with open(file_path, 'w') as f:
        dump(config_dict, f, default_flow_style=False)
def pull_image(image):
    """Pull a docker image given as '<repository>:<tag>'.

    Split on the LAST colon so a registry host that contains a port
    (e.g. "host:5000/img:tag") still parses correctly; the original
    split(":") took the first colon and produced a wrong repository/tag.
    @return: the raw output of the docker pull API call
    """
    repository, image_tag = image.rsplit(":", 1)
    client = docker.APIClient(base_url='unix://var/run/docker.sock')
    logger.info("Start pulling image: %s" % image)
    return client.pull(repository, image_tag)
def run_server(image, mem_limit=None, timeout=30, test_type="local", volume_name=None, db_slave=None):
    """Start a detached milvus server container from ``image`` and stream its logs.

    @params:
        image: docker image name to run
        mem_limit: currently unused (memory limiting is commented out below)
        timeout: currently unused (the wait-for-exit logic is commented out below)
        test_type: "local" mounts ./conf, ./logs, ./db; "remote" mounts
            directories under ``volume_name``
        volume_name: host path prefix for "remote" mounts; required when
            test_type == "remote"
        db_slave: number of extra db slave volumes to mount (remote mode only);
            must be an int to take effect
    @return: the started (detached) container object
    """
    import colors
    client = docker.from_env()
    # if mem_limit is None:
    #     mem_limit = psutil.virtual_memory().available
    # logger.info('Memory limit:', mem_limit)
    # cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1)
    # logger.info('Running on CPUs:', cpu_limit)
    # make sure local mount points exist; "already exists" errors are ignored
    for dir_item in ['logs', 'db']:
        try:
            os.mkdir(os.path.abspath(dir_item))
        except Exception as e:
            pass
    if test_type == "local":
        volumes = {
            os.path.abspath('conf'):
                {'bind': '/opt/milvus/conf', 'mode': 'ro'},
            os.path.abspath('logs'):
                {'bind': '/opt/milvus/logs', 'mode': 'rw'},
            os.path.abspath('db'):
                {'bind': '/opt/milvus/db', 'mode': 'rw'},
        }
    elif test_type == "remote":
        if volume_name is None:
            raise Exception("No volume name")
        remote_log_dir = volume_name+'/logs'
        remote_db_dir = volume_name+'/db'
        for dir_item in [remote_log_dir, remote_db_dir]:
            if not os.path.isdir(dir_item):
                os.makedirs(dir_item, exist_ok=True)
        volumes = {
            os.path.abspath('conf'):
                {'bind': '/opt/milvus/conf', 'mode': 'ro'},
            remote_log_dir:
                {'bind': '/opt/milvus/logs', 'mode': 'rw'},
            remote_db_dir:
                {'bind': '/opt/milvus/db', 'mode': 'rw'}
        }
        # add volumes: one extra data dir per slave, /opt/milvus/data2..dataN
        if db_slave and isinstance(db_slave, int):
            for i in range(2, db_slave+1):
                remote_db_dir = volume_name+'/data'+str(i)
                if not os.path.isdir(remote_db_dir):
                    os.makedirs(remote_db_dir, exist_ok=True)
                volumes[remote_db_dir] = {'bind': '/opt/milvus/data'+str(i), 'mode': 'rw'}
    # NOTE(review): if test_type is neither "local" nor "remote", `volumes`
    # is unbound here and the run() call raises NameError — confirm callers
    # only ever pass those two values.
    container = client.containers.run(
        image,
        volumes=volumes,
        runtime="nvidia",
        ports={'19530/tcp': 19530, '8080/tcp': 8080},
        environment=["OMP_NUM_THREADS=48"],
        # cpuset_cpus=cpu_limit,
        # mem_limit=mem_limit,
        # environment=[""],
        detach=True)
    # forward the container's log stream to our logger on a background thread
    def stream_logs():
        for line in container.logs(stream=True):
            logger.info(colors.color(line.decode().rstrip(), fg='blue'))
    # the daemon kwarg of Thread() only exists on Python 3
    if sys.version_info >= (3, 0):
        t = threading.Thread(target=stream_logs, daemon=True)
    else:
        t = threading.Thread(target=stream_logs)
        t.daemon = True
    t.start()
    logger.info('Container: %s started' % container)
    return container
    # exit_code = container.wait(timeout=timeout)
    # # Exit if exit code
    # if exit_code == 0:
    #     return container
    # elif exit_code is not None:
    #     print(colors.color(container.logs().decode(), fg='red'))
    #     raise Exception('Child process raised exception %s' % str(exit_code))
def restart_server(container):
    """Restart the given container (addressed by name) through the
    low-level docker API and return it."""
    api = docker.APIClient(base_url='unix://var/run/docker.sock')
    api.restart(container.name)
    logger.info('Container: %s restarted' % container.name)
    return container
def remove_container(container):
    """Force-remove the given container object and log the removal."""
    container.remove(force=True)
    logger.info('Container: %s removed' % container)
def remove_all_containers(image):
    """Stop and force-remove every running container created from ``image``.

    Failures are logged rather than raised so cleanup stays best-effort,
    but the underlying error is now included in the log message (the
    original discarded the caught exception).
    """
    client = docker.from_env()
    try:
        for container in client.containers.list():
            if image in container.image.tags:
                container.stop(timeout=30)
                container.remove(force=True)
    except Exception as e:
        logger.error("Containers removed failed: %s" % str(e))
def container_exists(image):
'''
Check if container existed with the given image name
@params: image name
@return: container if exists
'''
res = False
client = docker.from_env()
for container in client.containers.list():
if image in container.image.tags:
# True
res = container
return res
if __name__ == '__main__':
    # print(pull_image('branch-0.3.1-debug'))
    # NOTE(review): the original called stop_server(), which is not defined
    # anywhere in this module and raised NameError when run as a script.
    pass

View File

@ -0,0 +1,14 @@
node_modules
npm-debug.log
Dockerfile*
docker-compose*
.dockerignore
.git
.gitignore
.env
*/bin
*/obj
README.md
LICENSE
.vscode
__pycache__

13
tests/milvus_python_test/.gitignore vendored Normal file
View File

@ -0,0 +1,13 @@
.python-version
.pytest_cache
__pycache__
.vscode
.idea
test_out/
*.pyc
db/
logs/
.coverage

View File

@ -0,0 +1,14 @@
FROM python:3.6.8-jessie

LABEL Name=megasearch_engine_test Version=0.0.1

WORKDIR /app
ADD . /app

# Install build deps, install the python requirements, then purge the build
# deps and apt caches in the same layer to keep the image small.  The original
# `apt-get remove --purge -y` named no packages and therefore removed nothing.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libc-dev build-essential && \
    python3 -m pip install -r requirements.txt && \
    apt-get remove --purge -y libc-dev build-essential && \
    rm -rf /var/lib/apt/lists/*

ENTRYPOINT [ "/app/docker-entrypoint.sh" ]
CMD [ "start" ]

View File

@ -0,0 +1,143 @@
# Milvus test cases
## * Interfaces test
### 1. 连接测试
#### 1.1 连接
| cases | expected |
| ---------------- | -------------------------------------------- |
| 非法IP 123.0.0.2 | method: connect raise error in given timeout |
| 正常 uri | attr: connected assert true |
| 非法 uri | method: connect raise error in given timeout |
| 最大连接数 | all connection attrs: connected assert true |
| | |
#### 1.2 断开连接
| cases | expected |
| ------------------------ | ------------------- |
| 正常连接下,断开连接 | connect raise error |
| 正常连接下,重复断开连接 | connect raise error |
### 2. Table operation
#### 2.1 表创建
##### 2.1.1 表名
| cases | expected |
| ------------------------- | ----------- |
| 基础功能,参数正常 | status pass |
| 表名已存在 | status fail |
| 表名:"中文" | status pass |
| 表名带特殊字符: "-39fsd-" | status pass |
| 表名带空格: "test1 2" | status pass |
| invalid dim: 0 | raise error |
| invalid dim: -1 | raise error |
| invalid dim: 100000000 | raise error |
| invalid dim: "string" | raise error |
| index_type: 0 | status pass |
| index_type: 1 | status pass |
| index_type: 2 | status pass |
| index_type: string | raise error |
| | |
##### 2.1.2 维数支持
| cases | expected |
| --------------------- | ----------- |
| 维数: 0 | raise error |
| 维数负数: -1 | raise error |
| 维数最大值: 100000000 | raise error |
| 维数字符串: "string" | raise error |
| | |
##### 2.1.3 索引类型支持
| cases | expected |
| ---------------- | ----------- |
| 索引类型: 0 | status pass |
| 索引类型: 1 | status pass |
| 索引类型: 2 | status pass |
| 索引类型: string | raise error |
| | |
#### 2.2 表说明
| cases | expected |
| ---------------------- | -------------------------------- |
| 创建表后执行describe | 返回结构体,元素与创建表参数一致 |
| | |
#### 2.3 表删除
| cases | expected |
| -------------- | ---------------------- |
| 删除已存在表名 | has_table return False |
| 删除不存在表名 | status fail |
| | |
#### 2.4 表是否存在
| cases | expected |
| ----------------------- | ------------ |
| 存在表调用has_table | assert true |
| 不存在表调用has_table | assert false |
| | |
#### 2.5 查询表记录条数
| cases | expected |
| -------------------- | ------------------------ |
| 空表 | 0 |
| 空表插入数据(单条) | 1 |
| 空表插入数据(多条) | assert length of vectors |
#### 2.6 查询表数量
| cases | expected |
| --------------------------------------------- | -------------------------------- |
| 两张表一张空表一张有数据调用show tables | assert length of table list == 2 |
| | |
### 3. Add vectors
| interfaces | cases | expected |
| ----------- | --------------------------------------------------------- | ------------------------------------ |
| add_vectors | add basic | assert length of ids == nq |
| | add vectors into table not existed | status fail |
| | dim not match: single vector | status fail |
| | dim not match: vector list | status fail |
| | single vector element empty | status fail |
| | vector list element empty | status fail |
| | query immediately after adding | status pass |
| | query immediately after sleep 6s | status pass && length of result == 1 |
| | concurrent add with multi threads(share one connection) | status pass |
| | concurrent add with multi threads(independent connection) | status pass |
| | concurrent add with multi process(independent connection) | status pass |
| | index_type: 2 | status pass |
| | index_type: string | raise error |
| | | |
### 4. Search vectors
| interfaces | cases | expected |
| -------------- | ------------------------------------------------- | -------------------------------- |
| search_vectors | search basic(query vector in vectors, top-k<nq) | assert length of result == nq |
| | search vectors into table not existed | status fail |
| | basic top-k | score of query vectors == 100.0 |
| | invalid top-k: 0 | raise error |
| | invalid top-k: -1 | raise error |
| | invalid top-k: "string" | raise error |
| | top-k > nq | assert length of result == nq |
| | concurrent search | status pass |
| | query_range(get_current_day(), get_current_day()) | assert length of result == nq |
| | invalid query_range: "" | raise error |
| | query_range(get_last_day(2), get_last_day(1)) | assert length of result == 0 |
| | query_range(get_last_day(2), get_current_day()) | assert length of result == nq |
| | query_range(get_last_day(2), get_next_day(2)) | assert length of result == nq |
| | query_range(get_current_day(), get_next_day(2)) | assert length of result == nq |
| | query_range(get_next_day(1), get_next_day(2)) | assert length of result == 0 |
| | score: vector[i] = vector[i]+-0.01 | score > 99.9 |

View File

@ -0,0 +1,14 @@
# Requirements
* python 3.6.8
# How to use this Test Project
```shell
pytest . -q -v
```
with allure test report
```shell
pytest --alluredir=test_out . -q -v
allure serve test_out
```
# Contribution getting started
* Follow PEP-8 for naming and black for formatting.

View File

@ -0,0 +1,27 @@
* GLOBAL:
FORMAT = "%datetime | %level | %logger | %msg"
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log"
ENABLED = true
TO_FILE = true
TO_STANDARD_OUTPUT = false
SUBSECOND_PRECISION = 3
PERFORMANCE_TRACKING = false
MAX_LOG_FILE_SIZE = 209715200 ## Throw log files away after 200MB
* DEBUG:
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log"
ENABLED = true
* WARNING:
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log"
* TRACE:
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log"
* VERBOSE:
FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
TO_FILE = false
TO_STANDARD_OUTPUT = false
## Error logs
* ERROR:
ENABLED = true
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log"
* FATAL:
ENABLED = true
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log"

View File

@ -0,0 +1,32 @@
server_config:
address: 0.0.0.0
port: 19530
deploy_mode: single
time_zone: UTC+8
db_config:
primary_path: /opt/milvus
secondary_path:
backend_url: sqlite://:@:/
insert_buffer_size: 4
build_index_gpu: 0
preload_table:
metric_config:
enable_monitor: true
collector: prometheus
prometheus_config:
port: 8080
cache_config:
cpu_cache_capacity: 8
cpu_cache_threshold: 0.85
cache_insert_data: false
engine_config:
use_blas_threshold: 20
resource_config:
resource_pool:
- cpu
- gpu0

View File

@ -0,0 +1,128 @@
import socket
import pdb
import logging
import pytest
from utils import gen_unique_str
from milvus import Milvus, IndexType, MetricType
index_file_size = 10
def pytest_addoption(parser):
    """Register the command-line options used by this test suite.

    --ip:   Milvus server host (default: "localhost").
    --port: Milvus server port (default: 19530).
    """
    for opt_name, opt_default in (("--ip", "localhost"), ("--port", 19530)):
        parser.addoption(opt_name, action="store", default=opt_default)
def check_server_connection(request):
    """Return True when the configured --ip/--port resolves as a TCP endpoint.

    Local addresses (empty, "localhost", "127.0.0.1") are assumed reachable
    and skip the resolution probe entirely.
    """
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    connected = True
    # Only probe remote hosts; local addresses are taken on faith.
    if ip and (ip not in ['localhost', '127.0.0.1']):
        try:
            socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
        except Exception as e:
            # Fixed typo in the diagnostic message ("connnet" -> "connect").
            print("Socket connect failed: %s" % str(e))
            connected = False
    return connected
def get_args(request):
    """Collect the --ip/--port command-line options into a dict."""
    ip_opt = request.config.getoption("--ip")
    port_opt = request.config.getoption("--port")
    return {"ip": ip_opt, "port": port_opt}
@pytest.fixture(scope="module")
def connect(request):
    """Module-scoped Milvus client connected to the server under test.

    Aborts the whole pytest session when the server is unreachable and
    disconnects (best effort) when the module is torn down.
    """
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    milvus = Milvus()
    try:
        milvus.connect(host=ip, port=port)
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit during collection.
        pytest.exit("Milvus server can not connected, exit pytest ...")

    def fin():
        try:
            milvus.disconnect()
        except Exception:
            # Best-effort cleanup; the server may already be gone.
            pass

    request.addfinalizer(fin)
    return milvus
@pytest.fixture(scope="module")
def dis_connect(request):
    """Module-scoped Milvus client that has been connected then disconnected.

    Used by tests that verify operations fail on a dropped connection.
    """
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    milvus = Milvus()
    milvus.connect(host=ip, port=port)
    milvus.disconnect()

    def fin():
        try:
            milvus.disconnect()
        except Exception:
            # Narrowed from a bare `except:`; already-disconnected is fine.
            pass

    request.addfinalizer(fin)
    return milvus
@pytest.fixture(scope="module")
def args(request):
    """Expose the --ip/--port command-line options as a dict fixture."""
    return {
        "ip": request.config.getoption("--ip"),
        "port": request.config.getoption("--port"),
    }
@pytest.fixture(scope="function")
def table(request, connect):
    """Create a fresh L2-metric table for one test; drop tables afterwards.

    The name derives from the module-level ``table_id`` (default "test") and
    is made unique per invocation; dimension comes from module-level ``dim``.
    """
    ori_table_name = getattr(request.module, "table_id", "test")
    table_name = gen_unique_str(ori_table_name)
    # Default must be an int: the original passed the string "128" whenever
    # the test module did not define `dim`.
    dim = getattr(request.module, "dim", 128)
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    status = connect.create_table(param)
    if not status.OK():
        pytest.exit("Table can not be created, exit pytest ...")

    def teardown():
        # NOTE(review): drops every table on the server, not only the one
        # created here -- presumably intentional for isolation; confirm.
        status, table_names = connect.show_tables()
        for table_name in table_names:
            connect.delete_table(table_name)

    request.addfinalizer(teardown)
    return table_name
@pytest.fixture(scope="function")
def ip_table(request, connect):
    """Create a fresh IP-metric table for one test; drop tables afterwards.

    Mirrors the ``table`` fixture but uses MetricType.IP.
    """
    ori_table_name = getattr(request.module, "table_id", "test")
    table_name = gen_unique_str(ori_table_name)
    # Default must be an int: the original passed the string "128" whenever
    # the test module did not define `dim`.
    dim = getattr(request.module, "dim", 128)
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.IP}
    status = connect.create_table(param)
    if not status.OK():
        pytest.exit("Table can not be created, exit pytest ...")

    def teardown():
        # NOTE(review): drops every table on the server, not only the one
        # created here -- presumably intentional for isolation; confirm.
        status, table_names = connect.show_tables()
        for table_name in table_names:
            connect.delete_table(table_name)

    request.addfinalizer(teardown)
    return table_name

View File

@ -0,0 +1,9 @@
#!/bin/bash
# Container entrypoint: `start` keeps the container alive indefinitely;
# any other argument vector is executed as-is.
set -e

if [[ "$1" == 'start' ]]; then
    tail -f /dev/null
fi

exec "$@"

View File

@ -0,0 +1,9 @@
[pytest]
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_cli = true
log_level = 20
timeout = 300
level = 1

View File

@ -0,0 +1,25 @@
astroid==2.2.5
atomicwrites==1.3.0
attrs==19.1.0
importlib-metadata==0.15
isort==4.3.20
lazy-object-proxy==1.4.1
mccabe==0.6.1
more-itertools==7.0.0
numpy==1.16.3
pluggy==0.12.0
py==1.8.0
pylint==2.3.1
pytest==4.5.0
pytest-timeout==1.3.3
pytest-repeat==0.8.0
allure-pytest==2.7.0
pytest-print==0.1.2
pytest-level==0.1.1
six==1.12.0
thrift==0.11.0
typed-ast==1.3.5
wcwidth==0.1.7
wrapt==1.11.1
zipp==0.5.1
pymilvus-test>=0.2.0

View File

@ -0,0 +1,25 @@
astroid==2.2.5
atomicwrites==1.3.0
attrs==19.1.0
importlib-metadata==0.15
isort==4.3.20
lazy-object-proxy==1.4.1
mccabe==0.6.1
more-itertools==7.0.0
numpy==1.16.3
pluggy==0.12.0
py==1.8.0
pylint==2.3.1
pytest==4.5.0
pytest-timeout==1.3.3
pytest-repeat==0.8.0
allure-pytest==2.7.0
pytest-print==0.1.2
pytest-level==0.1.1
six==1.12.0
thrift==0.11.0
typed-ast==1.3.5
wcwidth==0.1.7
wrapt==1.11.1
zipp==0.5.1
pymilvus>=0.1.24

View File

@ -0,0 +1,24 @@
astroid==2.2.5
atomicwrites==1.3.0
attrs==19.1.0
importlib-metadata==0.15
isort==4.3.20
lazy-object-proxy==1.4.1
mccabe==0.6.1
more-itertools==7.0.0
numpy==1.16.3
pluggy==0.12.0
py==1.8.0
pylint==2.3.1
pytest==4.5.0
pytest-timeout==1.3.3
pytest-repeat==0.8.0
allure-pytest==2.7.0
pytest-print==0.1.2
pytest-level==0.1.1
six==1.12.0
thrift==0.11.0
typed-ast==1.3.5
wcwidth==0.1.7
wrapt==1.11.1
zipp==0.5.1

View File

@ -0,0 +1,4 @@
#!/bin/bash
# Run the whole suite, forwarding any extra arguments to pytest.
# Fixed shebang: the original read `#/bin/bash`, which is not a shebang at all.
pytest . $@

View File

@ -0,0 +1,41 @@
'''
Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
Unauthorized copying of this file, via any medium is strictly prohibited.
Proprietary and confidential.
'''
'''
Test Description:
This document is only a template to show how to write a auto-test script
本文档仅仅是个展示如何编写自动化测试脚本的模板
'''
import pytest
from milvus import Milvus
class TestConnection:
    """Template test class demonstrating how an auto-test case is written."""

    def test_connect_localhost(self):
        """
        TestCase1.1
        Test target: This case is to check if the server can be connected.
        Test method: Call API: milvus.connect to connect local milvus server, ip address: 127.0.0.1 and port: 19530, check the return status
        Expectation: Return status is OK.
        """
        milvus = Milvus()
        milvus.connect(host='127.0.0.1', port='19530')
        # `connected` is called as a method throughout the rest of the suite;
        # the original asserted the bound method object, which is always truthy.
        assert milvus.connected()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,386 @@
import pytest
from milvus import Milvus
import pdb
import threading
from multiprocessing import Process
from utils import *
__version__ = '0.5.0'
CONNECT_TIMEOUT = 12
class TestConnect:
def local_ip(self, args):
'''
check if ip is localhost or not
'''
if not args["ip"] or args["ip"] == 'localhost' or args["ip"] == "127.0.0.1":
return True
else:
return False
def test_disconnect(self, connect):
'''
target: test disconnect
method: disconnect a connected client
expected: connect failed after disconnected
'''
res = connect.disconnect()
assert res.OK()
with pytest.raises(Exception) as e:
res = connect.server_version()
def test_disconnect_repeatedly(self, connect, args):
'''
target: test disconnect repeatedly
method: disconnect a connected client, disconnect again
expected: raise an error after disconnected
'''
if not connect.connected():
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value)
res = milvus.disconnect()
with pytest.raises(Exception) as e:
res = milvus.disconnect()
else:
res = connect.disconnect()
with pytest.raises(Exception) as e:
res = connect.disconnect()
def test_connect_correct_ip_port(self, args):
    '''
    target: test connect with correct ip and port value
    method: set correct ip and port
    expected: connected is True
    '''
    milvus = Milvus()
    milvus.connect(host=args["ip"], port=args["port"])
    assert milvus.connected()
def test_connect_connected(self, args):
'''
target: test connect and disconnect with corrent ip and port value, assert connected value
method: set correct ip and port
expected: connected is False
'''
milvus = Milvus()
milvus.connect(host=args["ip"], port=args["port"])
milvus.disconnect()
assert not milvus.connected()
# TODO: Currently we test with remote IP, localhost testing need to add
def _test_connect_ip_localhost(self, args):
'''
target: test connect with ip value: localhost
method: set host localhost
expected: connected is True
'''
milvus = Milvus()
milvus.connect(host='localhost', port=args["port"])
assert milvus.connected()
@pytest.mark.timeout(CONNECT_TIMEOUT)
def test_connect_wrong_ip_null(self, args):
'''
target: test connect with wrong ip value
method: set host null
expected: not use default ip, connected is False
'''
milvus = Milvus()
ip = ""
with pytest.raises(Exception) as e:
milvus.connect(host=ip, port=args["port"], timeout=1)
assert not milvus.connected()
def test_connect_uri(self, args):
'''
target: test connect with correct uri
method: uri format and value are both correct
expected: connected is True
'''
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value)
assert milvus.connected()
def test_connect_uri_null(self, args):
'''
target: test connect with null uri
method: uri set null
expected: connected is True
'''
milvus = Milvus()
uri_value = ""
if self.local_ip(args):
milvus.connect(uri=uri_value, timeout=1)
assert milvus.connected()
else:
with pytest.raises(Exception) as e:
milvus.connect(uri=uri_value, timeout=1)
assert not milvus.connected()
@pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT)
def test_connect_wrong_uri_wrong_port_null(self, args):
    '''
    target: test uri connect with a port-less uri
    method: set uri port null
    expected: connect raises an exception
    '''
    milvus = Milvus()
    uri_value = "tcp://%s:" % args["ip"]
    with pytest.raises(Exception) as e:
        milvus.connect(uri=uri_value, timeout=1)
@pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT)
def test_connect_wrong_uri_wrong_ip_null(self, args):
    '''
    target: test uri connect with a host-less uri
    method: set uri ip null
    expected: connect raises and the client stays disconnected
    '''
    milvus = Milvus()
    uri_value = "tcp://:%s" % args["port"]
    with pytest.raises(Exception) as e:
        milvus.connect(uri=uri_value, timeout=1)
    assert not milvus.connected()
# TODO: enable
def _test_connect_with_multiprocess(self, args):
'''
target: test uri connect with multiprocess
method: set correct uri, test with multiprocessing connecting
expected: all connection is connected
'''
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
process_num = 4
processes = []
def connect(milvus):
milvus.connect(uri=uri_value)
with pytest.raises(Exception) as e:
milvus.connect(uri=uri_value)
assert milvus.connected()
for i in range(process_num):
milvus = Milvus()
p = Process(target=connect, args=(milvus, ))
processes.append(p)
p.start()
for p in processes:
p.join()
def test_connect_repeatedly(self, args):
'''
target: test connect repeatedly
method: connect again
expected: status.code is 0, and status.message shows have connected already
'''
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value)
milvus.connect(uri=uri_value)
assert milvus.connected()
def test_connect_disconnect_repeatedly_once(self, args):
'''
target: test connect and disconnect repeatedly
method: disconnect, and then connect, assert connect status
expected: status.code is 0
'''
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value)
milvus.disconnect()
milvus.connect(uri=uri_value)
assert milvus.connected()
def test_connect_disconnect_repeatedly_times(self, args):
'''
target: test connect and disconnect for 10 times repeatedly
method: disconnect, and then connect, assert connect status
expected: status.code is 0
'''
times = 10
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value)
for i in range(times):
milvus.disconnect()
milvus.connect(uri=uri_value)
assert milvus.connected()
# TODO: enable
def _test_connect_disconnect_with_multiprocess(self, args):
'''
target: test uri connect and disconnect repeatly with multiprocess
method: set correct uri, test with multiprocessing connecting and disconnecting
expected: all connection is connected after 10 times operation
'''
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
process_num = 4
processes = []
def connect(milvus):
milvus.connect(uri=uri_value)
milvus.disconnect()
milvus.connect(uri=uri_value)
assert milvus.connected()
for i in range(process_num):
milvus = Milvus()
p = Process(target=connect, args=(milvus, ))
processes.append(p)
p.start()
for p in processes:
p.join()
def test_connect_param_priority_no_port(self, args):
'''
target: both host_ip_port / uri are both given, if port is null, use the uri params
method: port set "", check if wrong uri connection is ok
expected: connect raise an exception and connected is false
'''
milvus = Milvus()
uri_value = "tcp://%s:19540" % args["ip"]
milvus.connect(host=args["ip"], port="", uri=uri_value)
assert milvus.connected()
def test_connect_param_priority_uri(self, args):
'''
target: both host_ip_port / uri are both given, if host is null, use the uri params
method: host set "", check if correct uri connection is ok
expected: connected is False
'''
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
with pytest.raises(Exception) as e:
milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1)
assert not milvus.connected()
def test_connect_param_priority_both_hostip_uri(self, args):
'''
target: both host_ip_port / uri are both given, and not null, use the uri params
method: check if wrong uri connection is ok
expected: connect raise an exception and connected is false
'''
milvus = Milvus()
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
with pytest.raises(Exception) as e:
milvus.connect(host=args["ip"], port=19540, uri=uri_value, timeout=1)
assert not milvus.connected()
def _test_add_vector_and_disconnect_concurrently(self):
'''
Target: test disconnect in the middle of add vectors
Method:
a. use coroutine or multi-processing, to simulate network crashing
b. data_set not too large incase disconnection happens when data is underd-preparing
c. data_set not too small incase disconnection happens when data has already been transferred
d. make sure disconnection happens when data is in-transport
Expected: Failure, get_table_row_count == 0
'''
pass
def _test_search_vector_and_disconnect_concurrently(self):
'''
Target: Test disconnect in the middle of search vectors(with large nq and topk)multiple times, and search/add vectors still work
Method:
a. coroutine or multi-processing, to simulate network crashing
b. connect, search and disconnect, repeating many times
c. connect and search, add vectors
Expected: Successfully searched back, successfully added
'''
pass
def _test_thread_safe_with_one_connection_shared_in_multi_threads(self):
'''
Target: test 1 connection thread safe
Method: 1 connection shared in multi-threads, all adding vectors, or other things
Expected: Functional as one thread
'''
pass
class TestConnectIPInvalid(object):
    """
    Test connect server with invalid ip
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ips()
    )
    def get_invalid_ip(self, request):
        # Yields each invalid host value produced by gen_invalid_ips().
        yield request.param

    @pytest.mark.level(2)
    @pytest.mark.timeout(CONNECT_TIMEOUT)
    def test_connect_with_invalid_ip(self, args, get_invalid_ip):
        # target: connect with an invalid host value
        # method: host taken from gen_invalid_ips
        # expected: connect raises and the client stays disconnected
        milvus = Milvus()
        ip = get_invalid_ip
        with pytest.raises(Exception) as e:
            milvus.connect(host=ip, port=args["port"], timeout=1)
        assert not milvus.connected()
class TestConnectPortInvalid(object):
    """
    Test connect server with invalid port
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ports()
    )
    def get_invalid_port(self, request):
        # Yields each invalid port value produced by gen_invalid_ports().
        yield request.param

    @pytest.mark.level(2)
    @pytest.mark.timeout(CONNECT_TIMEOUT)
    def test_connect_with_invalid_port(self, args, get_invalid_port):
        '''
        target: test ip:port connect with invalid port value
        method: set port in gen_invalid_ports
        expected: connect raises and the client stays disconnected
        '''
        milvus = Milvus()
        port = get_invalid_port
        with pytest.raises(Exception) as e:
            milvus.connect(host=args["ip"], port=port, timeout=1)
        assert not milvus.connected()
class TestConnectURIInvalid(object):
    """
    Test connect server with invalid uri
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_uris()
    )
    def get_invalid_uri(self, request):
        # Yields each invalid uri value produced by gen_invalid_uris().
        yield request.param

    @pytest.mark.level(2)
    @pytest.mark.timeout(CONNECT_TIMEOUT)
    def test_connect_with_invalid_uri(self, get_invalid_uri):
        '''
        target: test uri connect with invalid uri value
        method: set uri from gen_invalid_uris
        expected: connect raises and the client stays disconnected
        '''
        milvus = Milvus()
        uri_value = get_invalid_uri
        with pytest.raises(Exception) as e:
            milvus.connect(uri=uri_value, timeout=1)
        assert not milvus.connected()

View File

@ -0,0 +1,419 @@
import time
import random
import pdb
import logging
import threading
from builtins import Exception
from multiprocessing import Pool, Process
import pytest
from milvus import Milvus, IndexType
from utils import *
# Module-level fixtures and constants for the delete-vectors tests.
dim = 128  # vector dimension used throughout this module
index_file_size = 10
table_id = "test_delete"  # base name for generated tables
DELETE_TIMEOUT = 60  # per-test timeout (seconds) for delete operations
vectors = gen_vectors(100, dim)  # shared test data: 100 random vectors
class TestDeleteVectorsBase:
"""
generate invalid query range params
"""
@pytest.fixture(
scope="function",
params=[
(get_current_day(), get_current_day()),
(get_last_day(1), get_last_day(1)),
(get_next_day(1), get_next_day(1))
]
)
def get_invalid_range(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_invalid_range(self, connect, table, get_invalid_range):
    '''
    target: test delete vectors, no index created
    method: call `delete_vectors_by_range`, with invalid date params
    expected: delete fails (status not OK)
    '''
    start_date = get_invalid_range[0]
    end_date = get_invalid_range[1]
    status, ids = connect.add_vectors(table, vectors)
    status = connect.delete_vectors_by_range(table, start_date, end_date)
    assert not status.OK()
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_last_day(1)),
(get_last_day(2), get_current_day()),
(get_next_day(1), get_next_day(2))
]
)
def get_valid_range_no_result(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range_no_result(self, connect, table, get_valid_range_no_result):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params
expected: return code 0
'''
start_date = get_valid_range_no_result[0]
end_date = get_valid_range_no_result[1]
status, ids = connect.add_vectors(table, vectors)
time.sleep(2)
status = connect.delete_vectors_by_range(table, start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(table)
assert result == 100
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_next_day(2)),
(get_current_day(), get_next_day(2)),
]
)
def get_valid_range(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range(self, connect, table, get_valid_range):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params
expected: return code 0
'''
start_date = get_valid_range[0]
end_date = get_valid_range[1]
status, ids = connect.add_vectors(table, vectors)
time.sleep(2)
status = connect.delete_vectors_by_range(table, start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(table)
assert result == 0
@pytest.fixture(
scope="function",
params=gen_index_params()
)
def get_index_params(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range_index_created(self, connect, table, get_index_params):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params
expected: return code 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
index_params = get_index_params
logging.getLogger().info(index_params)
status, ids = connect.add_vectors(table, vectors)
status = connect.create_index(table, index_params)
logging.getLogger().info(status)
logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date))
status = connect.delete_vectors_by_range(table, start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(table)
assert result == 0
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_no_data(self, connect, table):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params, and no data in db
expected: return code 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
# status, ids = connect.add_vectors(table, vectors)
status = connect.delete_vectors_by_range(table, start_date, end_date)
assert status.OK()
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_table_not_existed(self, connect):
'''
target: test delete vectors, table not existed in db
method: call `delete_vectors_by_range`, with table not existed
expected: return code not 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
table_name = gen_unique_str("not_existed_table")
status = connect.delete_vectors_by_range(table_name, start_date, end_date)
assert not status.OK()
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_table_None(self, connect, table):
'''
target: test delete vectors, table set Nope
method: call `delete_vectors_by_range`, with table value is None
expected: return code not 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
table_name = None
with pytest.raises(Exception) as e:
status = connect.delete_vectors_by_range(table_name, start_date, end_date)
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range):
'''
target: test delete vectors is correct or not with multiple tables of L2
method: create 50 tables and add vectors into them , then delete vectors
in valid range
expected: return code 0
'''
nq = 100
vectors = gen_vectors(nq, dim)
table_list = []
for i in range(50):
table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables')
table_list.append(table_name)
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
connect.create_table(param)
status, ids = connect.add_vectors(table_name=table_name, records=vectors)
time.sleep(2)
start_date = get_valid_range[0]
end_date = get_valid_range[1]
for i in range(50):
status = connect.delete_vectors_by_range(table_list[i], start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(table_list[i])
assert result == 0
class TestDeleteVectorsIP:
"""
generate invalid query range params
"""
@pytest.fixture(
scope="function",
params=[
(get_current_day(), get_current_day()),
(get_last_day(1), get_last_day(1)),
(get_next_day(1), get_next_day(1))
]
)
def get_invalid_range(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_invalid_range(self, connect, ip_table, get_invalid_range):
    '''
    target: test delete vectors, no index created
    method: call `delete_vectors_by_range`, with invalid date params
    expected: delete fails (status not OK)
    '''
    start_date = get_invalid_range[0]
    end_date = get_invalid_range[1]
    status, ids = connect.add_vectors(ip_table, vectors)
    status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
    assert not status.OK()
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_last_day(1)),
(get_last_day(2), get_current_day()),
(get_next_day(1), get_next_day(2))
]
)
def get_valid_range_no_result(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range_no_result(self, connect, ip_table, get_valid_range_no_result):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params
expected: return code 0
'''
start_date = get_valid_range_no_result[0]
end_date = get_valid_range_no_result[1]
status, ids = connect.add_vectors(ip_table, vectors)
time.sleep(2)
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(ip_table)
assert result == 100
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_next_day(2)),
(get_current_day(), get_next_day(2)),
]
)
def get_valid_range(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range(self, connect, ip_table, get_valid_range):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params
expected: return code 0
'''
start_date = get_valid_range[0]
end_date = get_valid_range[1]
status, ids = connect.add_vectors(ip_table, vectors)
time.sleep(2)
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(ip_table)
assert result == 0
@pytest.fixture(
scope="function",
params=gen_index_params()
)
def get_index_params(self, request):
yield request.param
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range_index_created(self, connect, ip_table, get_index_params):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params
expected: return code 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
index_params = get_index_params
logging.getLogger().info(index_params)
status, ids = connect.add_vectors(ip_table, vectors)
status = connect.create_index(ip_table, index_params)
logging.getLogger().info(status)
logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date))
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(ip_table)
assert result == 0
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_no_data(self, connect, ip_table):
'''
target: test delete vectors, no index created
method: call `delete_vectors_by_range`, with valid date params, and no data in db
expected: return code 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
# status, ids = connect.add_vectors(table, vectors)
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
assert status.OK()
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_table_None(self, connect, ip_table):
'''
target: test delete vectors, table set Nope
method: call `delete_vectors_by_range`, with table value is None
expected: return code not 0
'''
start_date = get_current_day()
end_date = get_next_day(2)
table_name = None
with pytest.raises(Exception) as e:
status = connect.delete_vectors_by_range(table_name, start_date, end_date)
@pytest.mark.timeout(DELETE_TIMEOUT)
def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range):
'''
target: test delete vectors is correct or not with multiple tables of IP
method: create 50 tables and add vectors into them , then delete vectors
in valid range
expected: return code 0
'''
nq = 100
vectors = gen_vectors(nq, dim)
table_list = []
for i in range(50):
table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables')
table_list.append(table_name)
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.IP}
connect.create_table(param)
status, ids = connect.add_vectors(table_name=table_name, records=vectors)
time.sleep(2)
start_date = get_valid_range[0]
end_date = get_valid_range[1]
for i in range(50):
status = connect.delete_vectors_by_range(table_list[i], start_date, end_date)
assert status.OK()
status, result = connect.get_table_row_count(table_list[i])
assert result == 0
class TestDeleteVectorsParamsInvalid:
    """
    Test delete_vectors_by_range with invalid table names and query ranges
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_table_names()
    )
    def get_table_name(self, request):
        # Yields each invalid table name produced by gen_invalid_table_names().
        yield request.param

    @pytest.mark.level(2)
    def test_delete_vectors_table_invalid_name(self, connect, get_table_name):
        '''
        target: test delete_vectors_by_range with an invalid table name
        method: call with names from gen_invalid_table_names and a valid range
        expected: status not OK
        '''
        start_date = get_current_day()
        end_date = get_next_day(2)
        table_name = get_table_name
        logging.getLogger().info(table_name)
        # Removed unused locals top_k/nprobe from the original body.
        status = connect.delete_vectors_by_range(table_name, start_date, end_date)
        assert not status.OK()

    @pytest.fixture(
        scope="function",
        params=gen_invalid_query_ranges()
    )
    def get_query_ranges(self, request):
        # Yields each invalid query range produced by gen_invalid_query_ranges().
        yield request.param

    @pytest.mark.timeout(DELETE_TIMEOUT)
    def test_delete_vectors_range_invalid(self, connect, table, get_query_ranges):
        '''
        target: test delete function, with the wrong query_range
        method: delete with an invalid query_range
        expected: raise an error, and the connection is normal
        '''
        start_date = get_query_ranges[0][0]
        end_date = get_query_ranges[0][1]
        status, ids = connect.add_vectors(table, vectors)
        logging.getLogger().info(get_query_ranges)
        with pytest.raises(Exception) as e:
            status = connect.delete_vectors_by_range(table, start_date, end_date)

View File

@ -0,0 +1,966 @@
"""
For testing index operations, including `create_index`, `describe_index` and `drop_index` interfaces
"""
import logging
import pytest
import time
import pdb
import threading
from multiprocessing import Pool, Process
import numpy
from milvus import Milvus, IndexType, MetricType
from utils import *
# Module-level fixtures and constants for the index tests.
nb = 100000  # number of vectors added before building an index
dim = 128  # vector dimension
index_file_size = 10
vectors = gen_vectors(nb, dim)
# NOTE(review): this divides by the norm of the WHOLE matrix, not by each
# row's norm -- vectors are uniformly scaled, not unit-normalized per row.
# Confirm whether per-row normalization was intended (matters for IP metric).
vectors /= numpy.linalg.norm(vectors)
vectors = vectors.tolist()
BUILD_TIMEOUT = 60  # per-test timeout (seconds) for index builds
nprobe = 1  # default probe count for searches in this module
class TestIndexBase:
@pytest.fixture(
scope="function",
params=gen_index_params()
)
def get_index_params(self, request):
yield request.param
@pytest.fixture(
scope="function",
params=gen_simple_index_params()
)
def get_simple_index_params(self, request):
yield request.param
"""
******************************************************************
The following cases are used to test `create_index` function
******************************************************************
"""
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index(self, connect, table, get_index_params):
'''
target: test create index interface
method: create table and add vectors in it, create index
expected: return code equals to 0, and search success
'''
index_params = get_index_params
logging.getLogger().info(index_params)
status, ids = connect.add_vectors(table, vectors)
status = connect.create_index(table, index_params)
assert status.OK()
@pytest.mark.level(2)
def test_create_index_without_connect(self, dis_connect, table):
'''
target: test create index without connection
method: create table and add vectors in it, check if added successfully
expected: raise exception
'''
with pytest.raises(Exception) as e:
status = dis_connect.create_index(table, random.choice(gen_index_params()))
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_search_with_query_vectors(self, connect, table, get_index_params):
'''
target: test create index interface, search with more query vectors
method: create table and add vectors in it, create index
expected: return code equals to 0, and search success
'''
index_params = get_index_params
logging.getLogger().info(index_params)
status, ids = connect.add_vectors(table, vectors)
status = connect.create_index(table, index_params)
logging.getLogger().info(connect.describe_index(table))
query_vecs = [vectors[0], vectors[1], vectors[2]]
top_k = 5
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
assert status.OK()
assert len(result) == len(query_vecs)
logging.getLogger().info(result)
# TODO: enable
@pytest.mark.timeout(BUILD_TIMEOUT)
@pytest.mark.level(2)
def _test_create_index_multiprocessing(self, connect, table, args):
    '''
    target: test create index interface with multiprocess
    method: create table and add vectors in it, create index
    expected: return code equals to 0, and search success
    '''
    status, ids = connect.add_vectors(table, vectors)
    # Worker run in each child process; `table` is captured from the
    # enclosing scope, only the client handle is passed per process.
    def build(connect):
        status = connect.create_index(table)
        assert status.OK()
    process_num = 8
    processes = []
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    for i in range(process_num):
        # Each process gets its own connected Milvus client.
        m = Milvus()
        m.connect(uri=uri)
        p = Process(target=build, args=(m,))
        processes.append(p)
        p.start()
        # Stagger process start-up slightly.
        time.sleep(0.2)
    for p in processes:
        p.join()
    query_vec = [vectors[0]]
    top_k = 1
    status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
    assert len(result) == 1
    assert len(result[0]) == top_k
    # Searching for a stored vector against itself must give distance 0.
    assert result[0][0].distance == 0.0
# TODO: enable
@pytest.mark.timeout(BUILD_TIMEOUT)
def _test_create_index_multiprocessing_multitable(self, connect, args):
    '''
    target: test create index interface with multiprocess
    method: create process_num*loop_num tables, then have each process add
            vectors, build an index and search in its own slice of tables
    expected: return code equals to 0, and search success
    '''
    process_num = 8
    loop_num = 8
    processes = []
    table = []
    j = 0
    while j < (process_num*loop_num):
        table_name = gen_unique_str("test_create_index_multiprocessing")
        table.append(table_name)
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_type': IndexType.FLAT,
                 'store_raw_vector': False}
        connect.create_table(param)
        j = j + 1
    # Bug fix: the worker previously took no parameters although it was
    # started with args=(m, ids) (TypeError at Process start), and it also
    # rebound `ids` (its table-slice index) to the id list returned by
    # add_vectors, corrupting the table lookup on later iterations.
    def create_index(client, proc_id):
        for i in range(loop_num):
            tname = table[proc_id*process_num+i]
            status, vec_ids = client.add_vectors(tname, vectors)
            status = client.create_index(tname)
            assert status.OK()
            query_vec = [vectors[0]]
            top_k = 1
            status, result = client.search_vectors(tname, top_k, nprobe, query_vec)
            assert len(result) == 1
            assert len(result[0]) == top_k
            # A stored vector searched against itself has distance 0.
            assert result[0][0].distance == 0.0
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    for i in range(process_num):
        # Dedicated client per process.
        m = Milvus()
        m.connect(uri=uri)
        p = Process(target=create_index, args=(m, i))
        processes.append(p)
        p.start()
        time.sleep(0.2)
    for p in processes:
        p.join()
def test_create_index_table_not_existed(self, connect):
    '''
    target: test create index interface when table name not existed
    method: build an index against a freshly generated, never-created table name
    expected: return code not equals to 0, create index failed
    '''
    missing_table = gen_unique_str(self.__class__.__name__)
    result = connect.create_index(missing_table, random.choice(gen_index_params()))
    assert not result.OK()
def test_create_index_table_None(self, connect):
    '''
    target: test create index interface when table name is None
    method: call create_index with table_name=None
    expected: raise exception, create index failed
    '''
    with pytest.raises(Exception) as e:
        connect.create_index(None, random.choice(gen_index_params()))
def test_create_index_no_vectors(self, connect, table):
    '''
    target: test create index interface when there is no vectors in table
    method: build an index on an empty table
    expected: return code equals to 0
    '''
    index_param = random.choice(gen_index_params())
    assert connect.create_index(table, index_param).OK()
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_no_vectors_then_add_vectors(self, connect, table):
    '''
    target: test create index interface when there is no vectors in table, and does not affect the subsequent process
    method: create table and add no vectors in it, and then create index, add vectors in it
    expected: return code equals to 0
    '''
    status = connect.create_index(table, random.choice(gen_index_params()))
    # Bug fix: the build status was previously dropped on the floor; the
    # test's stated target is that building on an empty table succeeds.
    assert status.OK()
    status, ids = connect.add_vectors(table, vectors)
    assert status.OK()
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_same_index_repeatedly(self, connect, table):
    '''
    target: check if index can be created repeatedly, with the same create_index params
    method: create index after index have been built
    expected: return code success, and search ok
    '''
    status, ids = connect.add_vectors(table, vectors)
    index_params = random.choice(gen_index_params())
    # Build twice with identical params; both calls must succeed.
    status = connect.create_index(table, index_params)
    # Bug fix: the first build's status was previously ignored, so the test
    # only verified the rebuild.
    assert status.OK()
    status = connect.create_index(table, index_params)
    assert status.OK()
    query_vec = [vectors[0]]
    top_k = 1
    status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
    assert len(result) == 1
    assert len(result[0]) == top_k
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_different_index_repeatedly(self, connect, table):
    '''
    target: check if index can be created repeatedly, with the different create_index params
    method: build again with different params after an index already exists
    expected: return code 0, and describe index result equals with the second index params
    '''
    _, ids = connect.add_vectors(table, vectors)
    chosen = random.sample(gen_index_params(), 2)
    logging.getLogger().info(chosen)
    connect.create_index(table, chosen[0])
    rebuild_status = connect.create_index(table, chosen[1])
    assert rebuild_status.OK()
    _, info = connect.describe_index(table)
    # The table must now report the second (latest) index configuration.
    assert info._nlist == chosen[1]["nlist"]
    assert info._table_name == table
    assert info._index_type == chosen[1]["index_type"]
"""
******************************************************************
The following cases are used to test `describe_index` function
******************************************************************
"""
def test_describe_index(self, connect, table, get_index_params):
    '''
    target: test describe index interface
    method: add vectors, build an index, then describe it
    expected: return code 0, and index structure matches the build params
    '''
    params = get_index_params
    logging.getLogger().info(params)
    _, ids = connect.add_vectors(table, vectors)
    connect.create_index(table, params)
    _, info = connect.describe_index(table)
    logging.getLogger().info(info)
    assert info._nlist == params["nlist"]
    assert info._table_name == table
    assert info._index_type == params["index_type"]
def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params):
    '''
    target: test create, describe and drop index interface with multiple tables of L2
    method: create tables and add vectors in it, create index, call describe index
    expected: return code 0, and index structure as expected
    '''
    nq = 100
    vectors = gen_vectors(nq, dim)
    index_params = get_simple_index_params
    table_list = []
    # Create ten L2 tables, fill each and build the same index on all.
    for _ in range(10):
        name = gen_unique_str('test_create_index_multi_tables')
        table_list.append(name)
        connect.create_table({'table_name': name,
                              'dimension': dim,
                              'index_file_size': index_file_size,
                              'metric_type': MetricType.L2})
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(table_name=name, records=vectors)
        status = connect.create_index(name, index_params)
        assert status.OK()
    # Every table must report the index it was built with.
    for name in table_list:
        status, info = connect.describe_index(name)
        logging.getLogger().info(info)
        assert info._nlist == index_params["nlist"]
        assert info._table_name == name
        assert info._index_type == index_params["index_type"]
    # After dropping, each table reverts to the default FLAT index.
    for name in table_list:
        status = connect.drop_index(name)
        assert status.OK()
        status, info = connect.describe_index(name)
        logging.getLogger().info(info)
        assert info._nlist == 16384
        assert info._table_name == name
        assert info._index_type == IndexType.FLAT
@pytest.mark.level(2)
def test_describe_index_without_connect(self, dis_connect, table):
    '''
    target: test describe index without connection
    method: call describe_index on a disconnected client
    expected: raise exception
    '''
    with pytest.raises(Exception) as e:
        dis_connect.describe_index(table)
def test_describe_index_table_not_existed(self, connect):
    '''
    target: test describe index interface when table name not existed
    method: describe the index of a freshly generated, never-created table name
    expected: return code not equals to 0, describe index failed
    '''
    missing = gen_unique_str(self.__class__.__name__)
    status, _ = connect.describe_index(missing)
    assert not status.OK()
def test_describe_index_table_None(self, connect):
    '''
    target: test describe index interface when table name is None
    method: call describe_index with table_name=None
    expected: raise exception, describe index failed
    '''
    with pytest.raises(Exception) as e:
        connect.describe_index(None)
def test_describe_index_not_create(self, connect, table):
    '''
    target: test describe index interface when index not created
    method: add vectors to the table, then describe its index without
            ever building one
    expected: describe succeeds (default index parameters are reported)
    '''
    _, ids = connect.add_vectors(table, vectors)
    status, info = connect.describe_index(table)
    logging.getLogger().info(info)
    # Describing a table with no explicit index still returns OK.
    assert status.OK()
"""
******************************************************************
The following cases are used to test `drop_index` function
******************************************************************
"""
def test_drop_index(self, connect, table, get_index_params):
    '''
    target: test drop index interface
    method: build an index, drop it, then describe the table again
    expected: return code 0, and default index param
    '''
    params = get_index_params
    _, ids = connect.add_vectors(table, vectors)
    assert connect.create_index(table, params).OK()
    _, info = connect.describe_index(table)
    logging.getLogger().info(info)
    assert connect.drop_index(table).OK()
    _, info = connect.describe_index(table)
    logging.getLogger().info(info)
    # After the drop the table reverts to the default FLAT configuration.
    assert info._nlist == 16384
    assert info._table_name == table
    assert info._index_type == IndexType.FLAT
def test_drop_index_repeatly(self, connect, table, get_simple_index_params):
    '''
    target: test drop index repeatly
    method: create index, call drop index, and drop again
    expected: return code 0
    '''
    params = get_simple_index_params
    _, ids = connect.add_vectors(table, vectors)
    assert connect.create_index(table, params).OK()
    _, info = connect.describe_index(table)
    logging.getLogger().info(info)
    # Dropping twice in a row must both succeed (second drop is a no-op).
    assert connect.drop_index(table).OK()
    assert connect.drop_index(table).OK()
    _, info = connect.describe_index(table)
    logging.getLogger().info(info)
    assert info._nlist == 16384
    assert info._table_name == table
    assert info._index_type == IndexType.FLAT
@pytest.mark.level(2)
def test_drop_index_without_connect(self, dis_connect, table):
    '''
    target: test drop index without connection
    method: call drop_index on a disconnected client
    expected: raise exception
    '''
    with pytest.raises(Exception) as e:
        dis_connect.drop_index(table)
def test_drop_index_table_not_existed(self, connect):
    '''
    target: test drop index interface when table name not existed
    method: drop the index of a freshly generated, never-created table name
    expected: return code not equals to 0, drop index failed
    '''
    missing = gen_unique_str(self.__class__.__name__)
    assert not connect.drop_index(missing).OK()
def test_drop_index_table_None(self, connect):
    '''
    target: test drop index interface when table name is None
    method: call drop_index with table_name=None
    expected: raise exception, drop index failed
    '''
    with pytest.raises(Exception) as e:
        connect.drop_index(None)
def test_drop_index_table_not_create(self, connect, table):
    '''
    target: test drop index interface when index not created
    method: add vectors but never build an index, then drop
    expected: drop succeeds even though no index was ever built
    '''
    params = random.choice(gen_index_params())
    logging.getLogger().info(params)
    _, ids = connect.add_vectors(table, vectors)
    status, info = connect.describe_index(table)
    logging.getLogger().info(info)
    # Deliberately no create_index call before dropping.
    drop_status = connect.drop_index(table)
    logging.getLogger().info(drop_status)
    assert drop_status.OK()
def test_create_drop_index_repeatly(self, connect, table, get_simple_index_params):
    '''
    target: test create / drop index repeatly, use the same index params
    method: create index, drop index, repeated in a loop
    expected: return code 0
    '''
    params = get_simple_index_params
    _, ids = connect.add_vectors(table, vectors)
    for _ in range(2):
        assert connect.create_index(table, params).OK()
        _, info = connect.describe_index(table)
        logging.getLogger().info(info)
        assert connect.drop_index(table).OK()
        _, info = connect.describe_index(table)
        logging.getLogger().info(info)
        # Each drop must restore the default FLAT configuration.
        assert info._nlist == 16384
        assert info._table_name == table
        assert info._index_type == IndexType.FLAT
def test_create_drop_index_repeatly_different_index_params(self, connect, table):
    '''
    target: test create / drop index repeatly, use the different index params
    method: create index, drop index, repeated with a different index_params each time
    expected: return code 0
    '''
    chosen = random.sample(gen_index_params(), 2)
    _, ids = connect.add_vectors(table, vectors)
    for params in chosen:
        assert connect.create_index(table, params).OK()
        _, info = connect.describe_index(table)
        logging.getLogger().info(info)
        assert connect.drop_index(table).OK()
        _, info = connect.describe_index(table)
        logging.getLogger().info(info)
        # Each drop must restore the default FLAT configuration.
        assert info._nlist == 16384
        assert info._table_name == table
        assert info._index_type == IndexType.FLAT
class TestIndexIP:
    # Mirror of the L2 index test class, run against tables created with
    # MetricType.IP (inner-product) via the `ip_table` fixture.
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param
    @pytest.fixture(
        scope="function",
        params=gen_simple_index_params()
    )
    def get_simple_index_params(self, request):
        yield request.param
    """
    ******************************************************************
      The following cases are used to test `create_index` function
    ******************************************************************
    """
    @pytest.mark.timeout(BUILD_TIMEOUT)
    def test_create_index(self, connect, ip_table, get_index_params):
        '''
        target: test create index interface
        method: create table and add vectors in it, create index
        expected: return code equals to 0, and search success
        '''
        index_params = get_index_params
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        assert status.OK()
    @pytest.mark.level(2)
    def test_create_index_without_connect(self, dis_connect, ip_table):
        '''
        target: test create index without connection
        method: create table and add vectors in it, check if added successfully
        expected: raise exception
        '''
        with pytest.raises(Exception) as e:
            status = dis_connect.create_index(ip_table, random.choice(gen_index_params()))
    @pytest.mark.timeout(BUILD_TIMEOUT)
    def test_create_index_search_with_query_vectors(self, connect, ip_table, get_index_params):
        '''
        target: test create index interface, search with more query vectors
        method: create table and add vectors in it, create index
        expected: return code equals to 0, and search success
        '''
        index_params = get_index_params
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        logging.getLogger().info(connect.describe_index(ip_table))
        query_vecs = [vectors[0], vectors[1], vectors[2]]
        top_k = 5
        status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
        assert status.OK()
        # One result list per query vector.
        assert len(result) == len(query_vecs)
        # logging.getLogger().info(result)
    # TODO: enable
    @pytest.mark.timeout(BUILD_TIMEOUT)
    @pytest.mark.level(2)
    def _test_create_index_multiprocessing(self, connect, ip_table, args):
        '''
        target: test create index interface with multiprocess
        method: create table and add vectors in it, create index
        expected: return code equals to 0, and search success
        '''
        status, ids = connect.add_vectors(ip_table, vectors)
        # Worker run per process; `ip_table` is captured from the closure.
        def build(connect):
            status = connect.create_index(ip_table)
            assert status.OK()
        process_num = 8
        processes = []
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        for i in range(process_num):
            m = Milvus()
            m.connect(uri=uri)
            p = Process(target=build, args=(m,))
            processes.append(p)
            p.start()
            time.sleep(0.2)
        for p in processes:
            p.join()
        query_vec = [vectors[0]]
        top_k = 1
        status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
        assert len(result) == 1
        assert len(result[0]) == top_k
        assert result[0][0].distance == 0.0
    # TODO: enable
    @pytest.mark.timeout(BUILD_TIMEOUT)
    def _test_create_index_multiprocessing_multitable(self, connect, args):
        '''
        target: test create index interface with multiprocess
        method: create table and add vectors in it, create index
        expected: return code equals to 0, and search success
        '''
        process_num = 8
        loop_num = 8
        processes = []
        table = []
        j = 0
        while j < (process_num*loop_num):
            table_name = gen_unique_str("test_create_index_multiprocessing")
            table.append(table_name)
            # NOTE(review): unlike the L2 variant this table param carries no
            # index_file_size/metric_type; confirm MetricType.IP was intended.
            param = {'table_name': table_name,
                     'dimension': dim}
            connect.create_table(param)
            j = j + 1
        # Bug fix: the worker previously took no parameters although it was
        # started with args=(m, ids) (TypeError at Process start), and it
        # rebound `ids` (its table-slice index) to the id list returned by
        # add_vectors, corrupting the table lookup on later iterations.
        def create_index(client, proc_id):
            for i in range(loop_num):
                tname = table[proc_id*process_num+i]
                status, vec_ids = client.add_vectors(tname, vectors)
                status = client.create_index(tname)
                assert status.OK()
                query_vec = [vectors[0]]
                top_k = 1
                status, result = client.search_vectors(tname, top_k, nprobe, query_vec)
                assert len(result) == 1
                assert len(result[0]) == top_k
                assert result[0][0].distance == 0.0
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        for i in range(process_num):
            m = Milvus()
            m.connect(uri=uri)
            p = Process(target=create_index, args=(m, i))
            processes.append(p)
            p.start()
            time.sleep(0.2)
        for p in processes:
            p.join()
    def test_create_index_no_vectors(self, connect, ip_table):
        '''
        target: test create index interface when there is no vectors in table
        method: create table and add no vectors in it, and then create index
        expected: return code equals to 0
        '''
        status = connect.create_index(ip_table, random.choice(gen_index_params()))
        assert status.OK()
    @pytest.mark.timeout(BUILD_TIMEOUT)
    def test_create_index_no_vectors_then_add_vectors(self, connect, ip_table):
        '''
        target: test create index interface when there is no vectors in table, and does not affect the subsequent process
        method: create table and add no vectors in it, and then create index, add vectors in it
        expected: return code equals to 0
        '''
        status = connect.create_index(ip_table, random.choice(gen_index_params()))
        status, ids = connect.add_vectors(ip_table, vectors)
        assert status.OK()
    @pytest.mark.timeout(BUILD_TIMEOUT)
    def test_create_same_index_repeatedly(self, connect, ip_table):
        '''
        target: check if index can be created repeatedly, with the same create_index params
        method: create index after index have been built
        expected: return code success, and search ok
        '''
        status, ids = connect.add_vectors(ip_table, vectors)
        index_params = random.choice(gen_index_params())
        # Build twice with identical params; the rebuild must succeed.
        status = connect.create_index(ip_table, index_params)
        status = connect.create_index(ip_table, index_params)
        assert status.OK()
        query_vec = [vectors[0]]
        top_k = 1
        status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
        assert len(result) == 1
        assert len(result[0]) == top_k
    @pytest.mark.timeout(BUILD_TIMEOUT)
    def test_create_different_index_repeatedly(self, connect, ip_table):
        '''
        target: check if index can be created repeatedly, with the different create_index params
        method: create another index with different index_params after index have been built
        expected: return code 0, and describe index result equals with the second index params
        '''
        status, ids = connect.add_vectors(ip_table, vectors)
        index_params = random.sample(gen_index_params(), 2)
        logging.getLogger().info(index_params)
        status = connect.create_index(ip_table, index_params[0])
        status = connect.create_index(ip_table, index_params[1])
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        # The latest build wins.
        assert result._nlist == index_params[1]["nlist"]
        assert result._table_name == ip_table
        assert result._index_type == index_params[1]["index_type"]
    """
    ******************************************************************
      The following cases are used to test `describe_index` function
    ******************************************************************
    """
    def test_describe_index(self, connect, ip_table, get_index_params):
        '''
        target: test describe index interface
        method: create table and add vectors in it, create index, call describe index
        expected: return code 0, and index instructure
        '''
        index_params = get_index_params
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        assert result._nlist == index_params["nlist"]
        assert result._table_name == ip_table
        assert result._index_type == index_params["index_type"]
    def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params):
        '''
        target: test create, describe and drop index interface with multiple tables of IP
        method: create tables and add vectors in it, create index, call describe index
        expected: return code 0, and index instructure
        '''
        nq = 100
        vectors = gen_vectors(nq, dim)
        table_list = []
        # Create ten IP tables, fill each, and build the same index on all.
        for i in range(10):
            table_name = gen_unique_str('test_create_index_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.IP}
            connect.create_table(param)
            index_params = get_simple_index_params
            logging.getLogger().info(index_params)
            status, ids = connect.add_vectors(table_name=table_name, records=vectors)
            status = connect.create_index(table_name, index_params)
            assert status.OK()
        for i in range(10):
            status, result = connect.describe_index(table_list[i])
            logging.getLogger().info(result)
            assert result._nlist == index_params["nlist"]
            assert result._table_name == table_list[i]
            assert result._index_type == index_params["index_type"]
        # After dropping, each table reverts to the default FLAT index.
        for i in range(10):
            status = connect.drop_index(table_list[i])
            assert status.OK()
            status, result = connect.describe_index(table_list[i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[i]
            assert result._index_type == IndexType.FLAT
    @pytest.mark.level(2)
    def test_describe_index_without_connect(self, dis_connect, ip_table):
        '''
        target: test describe index without connection
        method: describe index, and check if describe successfully
        expected: raise exception
        '''
        with pytest.raises(Exception) as e:
            status = dis_connect.describe_index(ip_table)
    def test_describe_index_not_create(self, connect, ip_table):
        '''
        target: test describe index interface when index not created
        method: add vectors but never build an index, then describe it
        expected: describe succeeds (default index parameters are reported)
        '''
        status, ids = connect.add_vectors(ip_table, vectors)
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        assert status.OK()
    """
    ******************************************************************
      The following cases are used to test `drop_index` function
    ******************************************************************
    """
    def test_drop_index(self, connect, ip_table, get_index_params):
        '''
        target: test drop index interface
        method: create table and add vectors in it, create index, call drop index
        expected: return code 0, and default index param
        '''
        index_params = get_index_params
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        status = connect.drop_index(ip_table)
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        # After the drop the table reverts to the default FLAT configuration.
        assert result._nlist == 16384
        assert result._table_name == ip_table
        assert result._index_type == IndexType.FLAT
    def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params):
        '''
        target: test drop index repeatly
        method: create index, call drop index, and drop again
        expected: return code 0
        '''
        index_params = get_simple_index_params
        status, ids = connect.add_vectors(ip_table, vectors)
        status = connect.create_index(ip_table, index_params)
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        status = connect.drop_index(ip_table)
        assert status.OK()
        # The second drop is a no-op but must still succeed.
        status = connect.drop_index(ip_table)
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        assert result._nlist == 16384
        assert result._table_name == ip_table
        assert result._index_type == IndexType.FLAT
    @pytest.mark.level(2)
    def test_drop_index_without_connect(self, dis_connect, ip_table):
        '''
        target: test drop index without connection
        method: drop index, and check if drop successfully
        expected: raise exception
        '''
        with pytest.raises(Exception) as e:
            # Bug fix: drop_index takes only the table name; the extra
            # index-params argument previously made this raise TypeError
            # locally instead of exercising the disconnected-client path.
            status = dis_connect.drop_index(ip_table)
    def test_drop_index_table_not_create(self, connect, ip_table):
        '''
        target: test drop index interface when index not created
        method: add vectors but never build an index, then drop
        expected: drop succeeds even though no index was ever built
        '''
        index_params = random.choice(gen_index_params())
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(ip_table, vectors)
        status, result = connect.describe_index(ip_table)
        logging.getLogger().info(result)
        # no create index
        status = connect.drop_index(ip_table)
        logging.getLogger().info(status)
        assert status.OK()
    def test_create_drop_index_repeatly(self, connect, ip_table, get_simple_index_params):
        '''
        target: test create / drop index repeatly, use the same index params
        method: create index, drop index, four times
        expected: return code 0
        '''
        index_params = get_simple_index_params
        status, ids = connect.add_vectors(ip_table, vectors)
        for i in range(2):
            status = connect.create_index(ip_table, index_params)
            assert status.OK()
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
            status = connect.drop_index(ip_table)
            assert status.OK()
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == ip_table
            assert result._index_type == IndexType.FLAT
    def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table):
        '''
        target: test create / drop index repeatly, use the different index params
        method: create index, drop index, four times, each tme use different index_params to create index
        expected: return code 0
        '''
        index_params = random.sample(gen_index_params(), 2)
        status, ids = connect.add_vectors(ip_table, vectors)
        for i in range(2):
            status = connect.create_index(ip_table, index_params[i])
            assert status.OK()
            status, result = connect.describe_index(ip_table)
            # The freshly built index must be reported before the drop.
            assert result._nlist == index_params[i]["nlist"]
            assert result._table_name == ip_table
            assert result._index_type == index_params[i]["index_type"]
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
            status = connect.drop_index(ip_table)
            assert status.OK()
            status, result = connect.describe_index(ip_table)
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == ip_table
            assert result._index_type == IndexType.FLAT
class TestIndexTableInvalid(object):
    """
    Verify that create/describe/drop index all reject invalid table names.
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_table_names()
    )
    def get_table_name(self, request):
        yield request.param

    def test_create_index_with_invalid_tablename(self, connect, get_table_name):
        bad_name = get_table_name
        build_status = connect.create_index(bad_name, random.choice(gen_index_params()))
        assert not build_status.OK()

    def test_describe_index_with_invalid_tablename(self, connect, get_table_name):
        bad_name = get_table_name
        status, _ = connect.describe_index(bad_name)
        assert not status.OK()

    def test_drop_index_with_invalid_tablename(self, connect, get_table_name):
        bad_name = get_table_name
        assert not connect.drop_index(bad_name).OK()
class TestCreateIndexParamsInvalid(object):
    """
    Verify that create_index rejects malformed index-parameter dicts.
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_create_index_with_invalid_index_params(self, connect, table, get_index_params):
        bad_params = get_index_params
        index_type = bad_params["index_type"]
        nlist = bad_params["nlist"]
        logging.getLogger().info(bad_params)
        status, ids = connect.add_vectors(table, vectors)
        # Every invalid combination is expected to raise client-side.
        with pytest.raises(Exception) as e:
            status = connect.create_index(table, bad_params)

View File

@ -0,0 +1,180 @@
import pdb
import copy
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import numpy
from milvus import Milvus, IndexType, MetricType
from utils import *
# Module-wide fixtures/constants shared by every case in this file.
dim = 128
index_file_size = 10
table_id = "test_mix"
add_interval_time = 2
vectors = gen_vectors(100000, dim)
# NOTE(review): numpy.linalg.norm(vectors) without an axis is the Frobenius
# norm of the whole matrix, so this divides every row by one global scalar
# rather than unit-normalizing each vector — confirm intent (per-row
# normalization would be norm(vectors, axis=1, keepdims=True)). Also assumes
# gen_vectors returns an ndarray here, otherwise `/=` would fail — verify.
vectors /= numpy.linalg.norm(vectors)
vectors = vectors.tolist()
top_k = 1
nprobe = 1
epsilon = 0.0001
index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
class TestMixBase:
# TODO: enable
def _test_search_during_createIndex(self, args):
    # Stress case (currently disabled): keep searching a table while another
    # process concurrently mutates it.
    loops = 100000
    table = "test_search_during_createIndex"
    query_vecs = [vectors[0], vectors[1]]
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    id_0 = 0; id_1 = 0
    milvus_instance = Milvus()
    milvus_instance.connect(uri=uri)
    milvus_instance.create_table({'table_name': table,
         'dimension': dim,
         'index_file_size': index_file_size,
         'metric_type': MetricType.L2})
    for i in range(10):
        status, ids = milvus_instance.add_vectors(table, vectors)
        # logging.getLogger().info(ids)
        if i == 0:
            # Remember the ids of the first two inserted vectors; the search
            # worker asserts they stay the top-1 hits for themselves.
            id_0 = ids[0]; id_1 = ids[1]
    # NOTE(review): defined but never scheduled below — the "create index
    # during search" half of this test is dead code; confirm whether
    # p_create should target create_index instead of add_vectors.
    def create_index(milvus_instance):
        logging.getLogger().info("In create index")
        status = milvus_instance.create_index(table, index_params)
        logging.getLogger().info(status)
        status, result = milvus_instance.describe_index(table)
        logging.getLogger().info(result)
    def add_vectors(milvus_instance):
        logging.getLogger().info("In add vectors")
        status, ids = milvus_instance.add_vectors(table, vectors)
        logging.getLogger().info(status)
    def search(milvus_instance):
        for i in range(loops):
            status, result = milvus_instance.search_vectors(table, top_k, nprobe, query_vecs)
            logging.getLogger().info(status)
            assert result[0][0].id == id_0
            assert result[1][0].id == id_1
    milvus_instance = Milvus()
    milvus_instance.connect(uri=uri)
    p_search = Process(target=search, args=(milvus_instance, ))
    p_search.start()
    milvus_instance = Milvus()
    milvus_instance.connect(uri=uri)
    p_create = Process(target=add_vectors, args=(milvus_instance, ))
    p_create.start()
    p_create.join()
    # NOTE(review): p_search is never joined or terminated; its 100000-
    # iteration loop can outlive the test — confirm teardown behavior.
def test_mix_multi_tables(self, connect):
    '''
    target: test functions with multiple tables of different metric_types and index_types
    method: create 60 tables of which 30 are L2 and the other 30 are IP, add vectors
            into them, then check describe_index and search results
    expected: status ok
    '''
    nq = 10000
    vectors = gen_vectors(nq, dim)
    table_list = []
    idx = []    # ids of vectors 0/10/20 of each table, used to validate search hits
    # create tables and add vectors (tables 0-29: L2, tables 30-59: IP)
    for metric_type in (MetricType.L2, MetricType.IP):
        for i in range(30):
            table_name = gen_unique_str('test_mix_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': metric_type}
            connect.create_table(param)
            status, ids = connect.add_vectors(table_name=table_name, records=vectors)
            # check status BEFORE indexing into ids: on failure ids may be empty
            assert status.OK()
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
    # fix: the module only does `from time import sleep`, so the original
    # `time.sleep(2)` raised NameError
    sleep(2)
    # create index: FLAT on tables 0-9/30-39, IVFLAT on 10-19/40-49, IVF_SQ8 on 20-29/50-59
    for i in range(10):
        index_params = {'index_type': IndexType.FLAT, 'nlist': 16384}
        status = connect.create_index(table_list[i], index_params)
        assert status.OK()
        status = connect.create_index(table_list[30 + i], index_params)
        assert status.OK()
        index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
        status = connect.create_index(table_list[10 + i], index_params)
        assert status.OK()
        status = connect.create_index(table_list[40 + i], index_params)
        assert status.OK()
        index_params = {'index_type': IndexType.IVF_SQ8, 'nlist': 16384}
        status = connect.create_index(table_list[20 + i], index_params)
        assert status.OK()
        status = connect.create_index(table_list[50 + i], index_params)
        assert status.OK()
    # describe index: every table must report nlist=16384 and the type set above
    for i in range(10):
        for offset, expected_type in ((0, IndexType.FLAT),
                                      (10, IndexType.IVFLAT),
                                      (20, IndexType.IVF_SQ8),
                                      (30, IndexType.FLAT),
                                      (40, IndexType.IVFLAT),
                                      (50, IndexType.IVF_SQ8)):
            status, result = connect.describe_index(table_list[offset + i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[offset + i]
            assert result._index_type == expected_type
    # search: vectors 0/10/20 must be found in their own table
    query_vecs = [vectors[0], vectors[10], vectors[20]]
    for i in range(60):
        table = table_list[i]
        status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for j in range(len(query_vecs)):
            assert len(result[j]) == top_k
        for j in range(len(query_vecs)):
            assert check_result(result[j], idx[3 * i + j])
def check_result(result, id):
    # True when `id` appears among the first five returned hits
    # (or among all hits when fewer than five were returned).
    candidates = result[:5] if len(result) >= 5 else result
    return id in (hit.id for hit in candidates)

View File

@ -0,0 +1,77 @@
import logging
import pytest
# Expected server version string; test_server_version asserts the server reports it.
__version__ = '0.5.0'
class TestPing:
    def test_server_version(self, connect):
        '''
        target: test get the server version
        method: call the server_version method after connected
        expected: version should be the pymilvus version
        '''
        status, res = connect.server_version()
        assert res == __version__

    def test_server_status(self, connect):
        '''
        target: test get the server status
        method: call the server_status method after connected
        expected: status returned should be ok
        '''
        status, msg = connect.server_status()
        assert status.OK()

    def _test_server_cmd_with_params_version(self, connect):
        '''
        target: test cmd: version (disabled)
        method: cmd = "version"
        expected: the server returns its version string
        '''
        cmd = "version"
        status, msg = connect.cmd(cmd)
        logging.getLogger().info(status)
        logging.getLogger().info(msg)
        assert status.OK()
        assert msg == __version__

    def _test_server_cmd_with_params_others(self, connect):
        '''
        target: test cmd with an unrecognized command (disabled)
        method: send a command string the server does not know
        expected: status ok
        '''
        # NOTE(review): the command sent is "rm -rf test", while the original
        # docstring said "lalala" - confirm the intended payload before enabling.
        cmd = "rm -rf test"
        status, msg = connect.cmd(cmd)
        logging.getLogger().info(status)
        logging.getLogger().info(msg)
        assert status.OK()
        # assert msg == __version__

    def test_connected(self, connect):
        # the `connect` fixture must hand over a live connection
        assert connect.connected()
class TestPingDisconnect:
    def test_server_version(self, dis_connect):
        '''
        target: test get the server version, after disconnect
        method: call the server_version method on a disconnected client
        expected: an exception is raised and no version is returned
        '''
        res = None
        with pytest.raises(Exception) as e:
            # fix: the original called the undefined name `connect`, so the
            # resulting NameError satisfied pytest.raises and the test
            # passed without exercising the disconnected client at all
            status, res = dis_connect.server_version()
        assert res is None

    def test_server_status(self, dis_connect):
        '''
        target: test get the server status, after disconnect
        method: call the server_status method on a disconnected client
        expected: an exception is raised and no status is returned
        '''
        status = None
        with pytest.raises(Exception) as e:
            # fix: same undefined-name bug as above (`connect` -> `dis_connect`)
            status, msg = dis_connect.server_status()
        assert status is None

View File

@ -0,0 +1,650 @@
import pdb
import copy
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import numpy
from milvus import Milvus, IndexType, MetricType
from utils import *
# Test configuration for the search suite.
dim = 128                  # vector dimensionality used by every table in this module
table_id = "test_search"
add_interval_time = 2      # seconds to wait after add_vectors so data is searchable
vectors = gen_vectors(100, dim)
# vectors /= numpy.linalg.norm(vectors)
# vectors = vectors.tolist()
nrpobe = 1                 # NOTE(review): misspelling of "nprobe"; kept because tests below pass this global
epsilon = 0.001            # distance tolerance for exact-match assertions
class TestSearchBase:
    def init_data(self, connect, table, nb=100):
        '''
        Generate vectors and add them into `table` before searching.
        Reuses the module-level vectors when nb == 100; returns (added_vectors, ids).
        '''
        global vectors
        if nb == 100:
            add_vectors = vectors
        else:
            add_vectors = gen_vectors(nb, dim)
        # add_vectors /= numpy.linalg.norm(add_vectors)
        # add_vectors = add_vectors.tolist()
        status, ids = connect.add_vectors(table, add_vectors)
        # wait so the inserted vectors become searchable
        sleep(add_interval_time)
        return add_vectors, ids

    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    """
    generate top-k params
    """
    @pytest.fixture(
        scope="function",
        params=[1, 99, 101, 1024, 2048, 2049]
    )
    def get_top_k(self, request):
        # values straddle the server's accepted top-k range (see tests below)
        yield request.param
def test_search_top_k_flat_index(self, connect, table, get_top_k):
    '''
    target: test basic search function on a FLAT table, varying the top-k value
    method: search with the given vectors, check the result
    expected: search status ok and result length is top_k when top_k <= 2048,
              otherwise status not ok
    '''
    vectors, ids = self.init_data(connect, table)
    query_vec = [vectors[0]]
    top_k = get_top_k
    nprobe = 1
    # fix: pass the local `nprobe` instead of the misspelled module global
    # `nrpobe` (both are 1 today, but the local was silently unused and the
    # call depended on a typo'd global)
    status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
    if top_k <= 2048:
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        # the query vector itself was inserted, so the best distance is ~0
        assert result[0][0].distance <= epsilon
        assert check_result(result[0], ids[0])
    else:
        assert not status.OK()

def test_search_l2_index_params(self, connect, table, get_index_params):
    '''
    target: test basic search function on an L2 table for every index param set
    method: build the index, search with the given vectors, check the result
    expected: search status ok and result length is top_k
    '''
    index_params = get_index_params
    logging.getLogger().info(index_params)
    vectors, ids = self.init_data(connect, table)
    status = connect.create_index(table, index_params)
    query_vec = [vectors[0]]
    top_k = 10
    nprobe = 1
    # fix: use the local `nprobe` rather than the typo'd global `nrpobe`
    status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
    logging.getLogger().info(result)
    if top_k <= 1024:
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        # the query vector itself was inserted, so the best distance is ~0
        assert result[0][0].distance <= epsilon
    else:
        assert not status.OK()

def test_search_ip_index_params(self, connect, ip_table, get_index_params):
    '''
    target: test basic search function on an IP table for every index param set
    method: build the index, search with the given vectors, check the result
    expected: search status ok and result length is top_k
    '''
    index_params = get_index_params
    logging.getLogger().info(index_params)
    vectors, ids = self.init_data(connect, ip_table)
    status = connect.create_index(ip_table, index_params)
    query_vec = [vectors[0]]
    top_k = 10
    nprobe = 1
    # fix: use the local `nprobe` rather than the typo'd global `nrpobe`
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
    logging.getLogger().info(result)
    if top_k <= 1024:
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        # for IP the best score is the query's inner product with itself
        assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance)
    else:
        assert not status.OK()
@pytest.mark.level(2)
def test_search_vectors_without_connect(self, dis_connect, table):
    '''
    target: test search vectors without connection
    method: use a disconnected instance, call search and check the outcome
    expected: raise exception
    '''
    query_vectors = [vectors[0]]
    top_k = 1
    nprobe = 1
    with pytest.raises(Exception) as e:
        status, ids = dis_connect.search_vectors(table, top_k, nprobe, query_vectors)

def test_search_table_name_not_existed(self, connect, table):
    '''
    target: search a table that does not exist
    method: search with a random table_name that is not in the db
    expected: status not ok
    '''
    table_name = gen_unique_str("not_existed_table")
    top_k = 1
    nprobe = 1
    query_vecs = [vectors[0]]
    status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
    assert not status.OK()

def test_search_table_name_None(self, connect, table):
    '''
    target: search with table name None
    method: search with table_name: None
    expected: raise exception
    '''
    table_name = None
    top_k = 1
    nprobe = 1
    query_vecs = [vectors[0]]
    with pytest.raises(Exception) as e:
        status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
def test_search_top_k_query_records(self, connect, table):
    '''
    target: test search function with search param query_records
    method: search with query_records that are subarrays of the inserted vectors
    expected: status ok; each query's top hit is the query vector itself (distance ~0)
    '''
    top_k = 10
    nprobe = 1
    vectors, ids = self.init_data(connect, table)
    query_vecs = [vectors[0],vectors[55],vectors[99]]
    status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
    assert status.OK()
    # one result list per query vector
    assert len(result) == len(query_vecs)
    for i in range(len(query_vecs)):
        assert len(result[i]) == top_k
        # the query vector was inserted, so its best distance is ~0
        assert result[i][0].distance <= epsilon
"""
generate invalid query range params
"""
@pytest.fixture(
scope="function",
params=[
(get_current_day(), get_current_day()),
(get_last_day(1), get_last_day(1)),
(get_next_day(1), get_next_day(1))
]
)
def get_invalid_range(self, request):
yield request.param
def test_search_invalid_query_ranges(self, connect, table, get_invalid_range):
'''
target: search table with query ranges
method: search with the same query ranges
expected: status not ok
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0]]
query_ranges = [get_invalid_range]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
assert not status.OK()
assert len(result) == 0
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_last_day(1)),
(get_last_day(2), get_current_day()),
(get_next_day(1), get_next_day(2))
]
)
def get_valid_range_no_result(self, request):
yield request.param
def test_search_valid_query_ranges_no_result(self, connect, table, get_valid_range_no_result):
'''
target: search table with normal query ranges, but no data in db
method: search with query ranges (low, low)
expected: length of result is 0
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0]]
query_ranges = [get_valid_range_no_result]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
assert status.OK()
assert len(result) == 0
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_next_day(2)),
(get_current_day(), get_next_day(2)),
]
)
def get_valid_range(self, request):
yield request.param
def test_search_valid_query_ranges(self, connect, table, get_valid_range):
'''
target: search table with normal query ranges, but no data in db
method: search with query ranges (low, normal)
expected: length of result is 0
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0]]
query_ranges = [get_valid_range]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
assert status.OK()
assert len(result) == 1
assert result[0][0].distance <= epsilon
def test_search_distance_l2_flat_index(self, connect, table):
    '''
    target: search table and check the returned distance
    method: compare the returned distance with the value computed via Euclidean distance
    expected: the returned distance equals the computed value
    '''
    nb = 2
    top_k = 1
    nprobe = 1
    vectors, ids = self.init_data(connect, table, nb=nb)
    query_vecs = [[0.50 for i in range(dim)]]
    distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
    distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
    status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
    # the reported distance is compared after sqrt, i.e. the server reports squared L2
    assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

def test_search_distance_ip_flat_index(self, connect, ip_table):
    '''
    target: search ip_table and check the returned distance
    method: compare the returned distance with the value computed via inner product
    expected: the returned distance equals the computed value
    '''
    nb = 2
    top_k = 1
    nprobe = 1
    vectors, ids = self.init_data(connect, ip_table, nb=nb)
    index_params = {
        "index_type": IndexType.FLAT,
        "nlist": 16384
    }
    connect.create_index(ip_table, index_params)
    logging.getLogger().info(connect.describe_index(ip_table))
    query_vecs = [[0.50 for i in range(dim)]]
    distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
    distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
    # IP is a similarity: the best hit carries the *largest* inner product
    assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)

def test_search_distance_ip_index_params(self, connect, ip_table, get_index_params):
    '''
    target: search table and check the returned distance, for every index param set
    method: compare the returned distance with the value computed via inner product
    expected: the returned distance equals the computed value
    '''
    top_k = 2
    nprobe = 1
    vectors, ids = self.init_data(connect, ip_table, nb=2)
    index_params = get_index_params
    connect.create_index(ip_table, index_params)
    logging.getLogger().info(connect.describe_index(ip_table))
    query_vecs = [[0.50 for i in range(dim)]]
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
    distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
    distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
    assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
# TODO: enable
# @pytest.mark.repeat(5)
@pytest.mark.timeout(30)
def _test_search_concurrent(self, connect, table):
    # Disabled concurrency test: run the same search from 10 threads.
    # NOTE(review): search_vectors is called without the nprobe argument here,
    # unlike every enabled test in this module - confirm the signature before
    # enabling.
    vectors, ids = self.init_data(connect, table)
    thread_num = 10
    nb = 100
    top_k = 10
    threads = []
    query_vecs = vectors[nb//2:nb]
    def search():
        status, result = connect.search_vectors(table, top_k, query_vecs)
        assert len(result) == len(query_vecs)
        for i in range(len(query_vecs)):
            assert result[i][0].id in ids
            assert result[i][0].distance == 0.0
    for i in range(thread_num):
        x = threading.Thread(target=search, args=())
        threads.append(x)
        x.start()
    for th in threads:
        th.join()

# TODO: enable
@pytest.mark.timeout(30)
def _test_search_concurrent_multiprocessing(self, args):
    '''
    target: test concurrent search with multiple processes (disabled)
    method: search with 4 processes, each process uses its own connection
    expected: status ok and the returned vectors should be query_records
    '''
    # NOTE(review): search_vectors is missing the nprobe argument here too, and
    # `time.sleep(0.2)` below will raise NameError - the module only does
    # `from time import sleep`. Fix both before enabling.
    nb = 100
    top_k = 10
    process_num = 4
    processes = []
    table = gen_unique_str("test_search_concurrent_multiprocessing")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table,
             'dimension': dim,
             'index_type': IndexType.FLAT,
             'store_raw_vector': False}
    # create table
    milvus = Milvus()
    milvus.connect(uri=uri)
    milvus.create_table(param)
    vectors, ids = self.init_data(milvus, table, nb=nb)
    query_vecs = vectors[nb//2:nb]
    def search(milvus):
        status, result = milvus.search_vectors(table, top_k, query_vecs)
        assert len(result) == len(query_vecs)
        for i in range(len(query_vecs)):
            assert result[i][0].id in ids
            assert result[i][0].distance == 0.0
    for i in range(process_num):
        # each process gets its own connection
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=search, args=(milvus, ))
        processes.append(p)
        p.start()
        time.sleep(0.2)
    for p in processes:
        p.join()
def test_search_multi_table_L2(search, args):
    '''
    target: test search over multiple tables of L2
    method: add vectors into 10 tables, then search each of them
    expected: search status ok, correct result length, and the inserted
              vectors are found in their own table
    '''
    num = 10
    top_k = 10
    nprobe = 1
    tables = []
    idx = []    # ids of vectors 0/10/20 of each table, used to validate hits
    for i in range(num):
        table = gen_unique_str("test_add_multitable_%d" % i)
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': MetricType.L2}
        # create table
        milvus = Milvus()
        milvus.connect(uri=uri)
        milvus.create_table(param)
        status, ids = milvus.add_vectors(table, vectors)
        assert status.OK()
        assert len(ids) == len(vectors)
        tables.append(table)
        idx.append(ids[0])
        idx.append(ids[10])
        idx.append(ids[20])
    # fix: the module only does `from time import sleep`, so the original
    # `time.sleep(6)` raised NameError
    sleep(6)
    query_vecs = [vectors[0], vectors[10], vectors[20]]
    # start query from random table
    for i in range(num):
        table = tables[i]
        status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for j in range(len(query_vecs)):
            assert len(result[j]) == top_k
        for j in range(len(query_vecs)):
            assert check_result(result[j], idx[3 * i + j])
def test_search_multi_table_IP(search, args):
    '''
    target: test search over multiple tables of IP
    method: add vectors into 10 tables, then search each of them
    expected: search status ok, correct result length, and the inserted
              vectors are found in their own table
    '''
    num = 10
    top_k = 10
    nprobe = 1
    tables = []
    idx = []    # ids of vectors 0/10/20 of each table, used to validate hits
    for i in range(num):
        table = gen_unique_str("test_add_multitable_%d" % i)
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 # fix: this IP test created L2 tables (copy-paste from the L2
                 # variant), so the IP metric was never exercised
                 'metric_type': MetricType.IP}
        # create table
        milvus = Milvus()
        milvus.connect(uri=uri)
        milvus.create_table(param)
        status, ids = milvus.add_vectors(table, vectors)
        assert status.OK()
        assert len(ids) == len(vectors)
        tables.append(table)
        idx.append(ids[0])
        idx.append(ids[10])
        idx.append(ids[20])
    # fix: the module only does `from time import sleep`, so the original
    # `time.sleep(6)` raised NameError
    sleep(6)
    query_vecs = [vectors[0], vectors[10], vectors[20]]
    # start query from random table
    for i in range(num):
        table = tables[i]
        status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for j in range(len(query_vecs)):
            assert len(result[j]) == top_k
        for j in range(len(query_vecs)):
            assert check_result(result[j], idx[3 * i + j])
"""
******************************************************************
# The following cases are used to test `search_vectors` function
# with invalid table_name top-k / nprobe / query_range
******************************************************************
"""
class TestSearchParamsInvalid(object):
    # one randomly chosen valid index-param set; logged for traceability
    index_params = random.choice(gen_index_params())
    logging.getLogger().info(index_params)

    def init_data(self, connect, table, nb=100):
        '''
        Generate vectors and add them into `table` before searching.
        Reuses the module-level vectors when nb == 100; returns (added_vectors, ids).
        '''
        global vectors
        if nb == 100:
            add_vectors = vectors
        else:
            add_vectors = gen_vectors(nb, dim)
        status, ids = connect.add_vectors(table, add_vectors)
        # wait so the inserted vectors become searchable
        sleep(add_interval_time)
        return add_vectors, ids
"""
Test search table with invalid table names
"""
@pytest.fixture(
scope="function",
params=gen_invalid_table_names()
)
def get_table_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_search_with_invalid_tablename(self, connect, get_table_name):
table_name = get_table_name
logging.getLogger().info(table_name)
top_k = 1
nprobe = 1
query_vecs = gen_vectors(1, dim)
status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
assert not status.OK()
"""
Test search table with invalid top-k
"""
@pytest.fixture(
scope="function",
params=gen_invalid_top_ks()
)
def get_top_k(self, request):
yield request.param
@pytest.mark.level(2)
def test_search_with_invalid_top_k(self, connect, table, get_top_k):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = get_top_k
logging.getLogger().info(top_k)
nprobe = 1
query_vecs = gen_vectors(1, dim)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
res = connect.server_version()
@pytest.mark.level(2)
def test_search_with_invalid_top_k_ip(self, connect, ip_table, get_top_k):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = get_top_k
logging.getLogger().info(top_k)
nprobe = 1
query_vecs = gen_vectors(1, dim)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
res = connect.server_version()
"""
Test search table with invalid nprobe
"""
@pytest.fixture(
scope="function",
params=gen_invalid_nprobes()
)
def get_nprobes(self, request):
yield request.param
@pytest.mark.level(2)
def test_search_with_invalid_nrpobe(self, connect, table, get_nprobes):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = get_nprobes
logging.getLogger().info(nprobe)
query_vecs = gen_vectors(1, dim)
if isinstance(nprobe, int) and nprobe > 0:
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
assert not status.OK()
else:
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
@pytest.mark.level(2)
def test_search_with_invalid_nrpobe_ip(self, connect, ip_table, get_nprobes):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = get_nprobes
logging.getLogger().info(nprobe)
query_vecs = gen_vectors(1, dim)
if isinstance(nprobe, int) and nprobe > 0:
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
assert not status.OK()
else:
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
"""
Test search table with invalid query ranges
"""
@pytest.fixture(
scope="function",
params=gen_invalid_query_ranges()
)
def get_query_ranges(self, request):
yield request.param
@pytest.mark.level(2)
def test_search_flat_with_invalid_query_range(self, connect, table, get_query_ranges):
'''
target: test search fuction, with the wrong query_range
method: search with query_range
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = 1
query_vecs = [vectors[0]]
query_ranges = get_query_ranges
logging.getLogger().info(query_ranges)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table, 1, nprobe, query_vecs, query_ranges=query_ranges)
@pytest.mark.level(2)
def test_search_flat_with_invalid_query_range_ip(self, connect, ip_table, get_query_ranges):
'''
target: test search fuction, with the wrong query_range
method: search with query_range
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = 1
query_vecs = [vectors[0]]
query_ranges = get_query_ranges
logging.getLogger().info(query_ranges)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(ip_table, 1, nprobe, query_vecs, query_ranges=query_ranges)
def check_result(result, id):
    # Report whether `id` is among the first five returned hits
    # (all available hits when fewer than five were returned).
    window = result[:5] if len(result) >= 5 else result
    return id in (item.id for item in window)

View File

@ -0,0 +1,883 @@
import random
import pdb
import pytest
import logging
import itertools
from time import sleep
from multiprocessing import Process
import numpy
from milvus import Milvus
from milvus import IndexType, MetricType
from utils import *
# Test configuration for the table-management suite.
dim = 128                        # vector dimensionality for every table created here
delete_table_interval_time = 3   # seconds; not referenced in this chunk - presumably a post-delete wait
index_file_size = 10             # index_file_size param passed to create_table
vectors = gen_vectors(100, dim)
class TestTable:
    """
    ******************************************************************
    The following cases are used to test `create_table` function
    ******************************************************************
    """
    def test_create_table(self, connect):
        '''
        target: test create normal table
        method: create table with correct params
        expected: create status return ok
        '''
        table_name = gen_unique_str("test_table")
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        status = connect.create_table(param)
        assert status.OK()

    def test_create_table_ip(self, connect):
        '''
        target: test create normal table with IP metric
        method: create table with correct params
        expected: create status return ok
        '''
        table_name = gen_unique_str("test_table")
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.IP}
        status = connect.create_table(param)
        assert status.OK()
@pytest.mark.level(2)
def test_create_table_without_connection(self, dis_connect):
    '''
    target: test create table, without connection
    method: create table with correct params, with a disconnected instance
    expected: create raise exception
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    with pytest.raises(Exception) as e:
        status = dis_connect.create_table(param)

def test_create_table_existed(self, connect):
    '''
    target: test create table when the table name already exists
    method: create table twice with the same table_name
    expected: the second create status is not ok
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    status = connect.create_table(param)
    # creating the same table a second time must be rejected
    status = connect.create_table(param)
    assert not status.OK()

@pytest.mark.level(2)
def test_create_table_existed_ip(self, connect):
    '''
    target: test create table when the table name already exists (IP metric)
    method: create table twice with the same table_name
    expected: the second create status is not ok
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.IP}
    status = connect.create_table(param)
    status = connect.create_table(param)
    assert not status.OK()
def test_create_table_None(self, connect):
    '''
    target: test create table with table name None
    method: create table, param table_name is None
    expected: create raise error
    '''
    param = {'table_name': None,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    with pytest.raises(Exception) as e:
        status = connect.create_table(param)

def test_create_table_no_dimension(self, connect):
    '''
    target: test create table with no dimension param
    method: create table without the 'dimension' key
    expected: raise an exception
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    with pytest.raises(Exception) as e:
        status = connect.create_table(param)
def test_create_table_no_file_size(self, connect):
    '''
    target: test create table with no index_file_size param
    method: create table without the 'index_file_size' key
    expected: create status return ok, default 1024 is used
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'metric_type': MetricType.L2}
    status = connect.create_table(param)
    logging.getLogger().info(status)
    status, result = connect.describe_table(table_name)
    logging.getLogger().info(result)
    assert result.index_file_size == 1024

def test_create_table_no_metric_type(self, connect):
    '''
    target: test create table with no metric_type param
    method: create table without the 'metric_type' key
    expected: create status return ok, default L2 is used
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size}
    status = connect.create_table(param)
    status, result = connect.describe_table(table_name)
    logging.getLogger().info(result)
    assert result.metric_type == MetricType.L2
"""
******************************************************************
The following cases are used to test `describe_table` function
******************************************************************
"""
def test_table_describe_result(self, connect):
'''
target: test describe table created with correct params
method: create table, assert the value returned by describe method
expected: table_name equals with the table name created
'''
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
connect.create_table(param)
status, res = connect.describe_table(table_name)
assert res.table_name == table_name
assert res.metric_type == MetricType.L2
def test_table_describe_table_name_ip(self, connect):
'''
target: test describe table created with correct params
method: create table, assert the value returned by describe method
expected: table_name equals with the table name created
'''
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.IP}
connect.create_table(param)
status, res = connect.describe_table(table_name)
assert res.table_name == table_name
assert res.metric_type == MetricType.IP
# TODO: enable
@pytest.mark.level(2)
def _test_table_describe_table_name_multiprocessing(self, connect, args):
    '''
    target: test describe table from multiple processes (disabled)
    method: create a table, then describe it from several processes
    expected: every process sees the created table_name
    '''
    table_name = gen_unique_str("test_table")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    connect.create_table(param)
    def describetable(milvus):
        status, res = milvus.describe_table(table_name)
        assert res.table_name == table_name
    process_num = 4
    processes = []
    for i in range(process_num):
        # each process uses its own connection
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=describetable, args=(milvus,))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
@pytest.mark.level(2)
def test_table_describe_without_connection(self, table, dis_connect):
    '''
    target: test describe table, without connection
    method: describe table with correct params, with a disconnected instance
    expected: describe raise exception
    '''
    with pytest.raises(Exception) as e:
        status = dis_connect.describe_table(table)

def test_table_describe_dimension(self, connect):
    '''
    target: test describe table created with correct params
    method: create table, check the dimension value returned by describe_table
    expected: dimension equals the dimension used at creation
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim+1,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    connect.create_table(param)
    status, res = connect.describe_table(table_name)
    # dim+1 distinguishes this table's dimension from the module default
    assert res.dimension == dim+1
"""
******************************************************************
The following cases are used to test `delete_table` function
******************************************************************
"""
def test_delete_table(self, connect, table):
'''
target: test delete table created with correct params
method: create table and then delete,
assert the value returned by delete method
expected: status ok, and no table in tables
'''
status = connect.delete_table(table)
assert not connect.has_table(table)
def test_delete_table_ip(self, connect, ip_table):
'''
target: test delete table created with correct params
method: create table and then delete,
assert the value returned by delete method
expected: status ok, and no table in tables
'''
status = connect.delete_table(ip_table)
assert not connect.has_table(ip_table)
@pytest.mark.level(2)
def test_table_delete_without_connection(self, table, dis_connect):
'''
target: test describe table, without connection
method: describe table with correct params, with a disconnected instance
expected: describe raise exception
'''
with pytest.raises(Exception) as e:
status = dis_connect.delete_table(table)
def test_delete_table_not_existed(self, connect):
'''
target: test delete table not in index
method: delete all tables, and delete table again,
assert the value returned by delete method
expected: status not ok
'''
table_name = gen_unique_str("test_table")
status = connect.delete_table(table_name)
assert not status.code==0
def test_delete_table_repeatedly(self, connect):
'''
target: test delete table created with correct params
method: create table and delete new table repeatedly,
assert the value returned by delete method
expected: create ok and delete ok
'''
loops = 1
for i in range(loops):
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
connect.create_table(param)
status = connect.delete_table(table_name)
time.sleep(1)
assert not connect.has_table(table_name)
def test_delete_create_table_repeatedly(self, connect):
'''
target: test delete and create the same table repeatedly
method: try to create the same table and delete repeatedly,
assert the value returned by delete method
expected: create ok and delete ok
'''
loops = 5
for i in range(loops):
table_name = "test_table"
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
connect.create_table(param)
status = connect.delete_table(table_name)
time.sleep(2)
assert status.OK()
@pytest.mark.level(2)
def test_delete_create_table_repeatedly_ip(self, connect):
'''
target: test delete and create the same table repeatedly
method: try to create the same table and delete repeatedly,
assert the value returned by delete method
expected: create ok and delete ok
'''
loops = 5
for i in range(loops):
table_name = "test_table"
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.IP}
connect.create_table(param)
status = connect.delete_table(table_name)
time.sleep(2)
assert status.OK()
    # TODO: enable
    @pytest.mark.level(2)
    def _test_delete_table_multiprocessing(self, args):
        '''
        target: test delete table with multiprocess
        method: create table and then delete,
            assert the value returned by delete method
        expected: status ok, and no table in tables
        '''
        # NOTE: the leading underscore keeps pytest from collecting this test (disabled).
        process_num = 6
        processes = []
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        def deletetable(milvus):
            # NOTE(review): `table` is not defined anywhere in this method's
            # scope -- re-enabling this test as-is would raise NameError;
            # it presumably needs the `table` fixture. TODO confirm.
            status = milvus.delete_table(table)
            # assert not status.code==0
            assert milvus.has_table(table)
            assert status.OK()
        for i in range(process_num):
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=deletetable, args=(milvus,))
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
    # TODO: enable
    @pytest.mark.level(2)
    def _test_delete_table_multiprocessing_multitable(self, connect):
        '''
        target: test delete table with multiprocess
        method: create table and then delete,
            assert the value returned by delete method
        expected: status ok, and no table in tables
        '''
        # NOTE: the leading underscore keeps pytest from collecting this test (disabled).
        process_num = 5
        loop_num = 2
        processes = []
        table = []
        j = 0
        # Pre-create process_num * loop_num tables; each worker deletes its own slice.
        while j < (process_num*loop_num):
            table_name = gen_unique_str("test_delete_table_with_multiprocessing")
            table.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.L2}
            connect.create_table(param)
            j = j + 1
        def delete(connect,ids):
            i = 0
            while i < loop_num:
                # assert connect.has_table(table[ids*8+i])
                # NOTE(review): with process_num=5 and only 10 tables created,
                # `ids*process_num+i` reaches index 21 and would raise
                # IndexError; the intended stride is probably
                # `ids*loop_num+i`. TODO confirm before re-enabling.
                status = connect.delete_table(table[ids*process_num+i])
                time.sleep(2)
                assert status.OK()
                assert not connect.has_table(table[ids*process_num+i])
                i = i + 1
        for i in range(process_num):
            ids = i
            p = Process(target=delete, args=(connect,ids))
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
"""
******************************************************************
The following cases are used to test `has_table` function
******************************************************************
"""
def test_has_table(self, connect):
'''
target: test if the created table existed
method: create table, assert the value returned by has_table method
expected: True
'''
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
connect.create_table(param)
assert connect.has_table(table_name)
def test_has_table_ip(self, connect):
'''
target: test if the created table existed
method: create table, assert the value returned by has_table method
expected: True
'''
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.IP}
connect.create_table(param)
assert connect.has_table(table_name)
@pytest.mark.level(2)
def test_has_table_without_connection(self, table, dis_connect):
'''
target: test has table, without connection
method: calling has table with correct params, with a disconnected instance
expected: has table raise exception
'''
with pytest.raises(Exception) as e:
status = dis_connect.has_table(table)
def test_has_table_not_existed(self, connect):
'''
target: test if table not created
method: random a table name, which not existed in db,
assert the value returned by has_table method
expected: False
'''
table_name = gen_unique_str("test_table")
assert not connect.has_table(table_name)
"""
******************************************************************
The following cases are used to test `show_tables` function
******************************************************************
"""
def test_show_tables(self, connect):
'''
target: test show tables is correct or not, if table created
method: create table, assert the value returned by show_tables method is equal to 0
expected: table_name in show tables
'''
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.L2}
connect.create_table(param)
status, result = connect.show_tables()
assert status.OK()
assert table_name in result
def test_show_tables_ip(self, connect):
'''
target: test show tables is correct or not, if table created
method: create table, assert the value returned by show_tables method is equal to 0
expected: table_name in show tables
'''
table_name = gen_unique_str("test_table")
param = {'table_name': table_name,
'dimension': dim,
'index_file_size': index_file_size,
'metric_type': MetricType.IP}
connect.create_table(param)
status, result = connect.show_tables()
assert status.OK()
assert table_name in result
@pytest.mark.level(2)
def test_show_tables_without_connection(self, dis_connect):
'''
target: test show_tables, without connection
method: calling show_tables with correct params, with a disconnected instance
expected: show_tables raise exception
'''
with pytest.raises(Exception) as e:
status = dis_connect.show_tables()
def test_show_tables_no_table(self, connect):
'''
target: test show tables is correct or not, if no table in db
method: delete all tables,
assert the value returned by show_tables method is equal to []
expected: the status is ok, and the result is equal to []
'''
status, result = connect.show_tables()
if result:
for table_name in result:
connect.delete_table(table_name)
time.sleep(delete_table_interval_time)
status, result = connect.show_tables()
assert status.OK()
assert len(result) == 0
    # TODO: enable
    @pytest.mark.level(2)
    def _test_show_tables_multiprocessing(self, connect, args):
        '''
        target: test show tables is correct or not with processes
        method: create table, assert the value returned by show_tables method is equal to 0
        expected: table_name in show tables
        '''
        # NOTE: the leading underscore keeps pytest from collecting this test (disabled).
        table_name = gen_unique_str("test_table")
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        connect.create_table(param)
        # Worker body: each process opens its own connection and must see the table.
        def showtables(milvus):
            status, result = milvus.show_tables()
            assert status.OK()
            assert table_name in result
        process_num = 8
        processes = []
        for i in range(process_num):
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=showtables, args=(milvus,))
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
"""
******************************************************************
The following cases are used to test `preload_table` function
******************************************************************
"""
"""
generate valid create_index params
"""
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        # Parametrizes each test over every valid (index_type, nlist) pair.
        yield request.param
@pytest.mark.level(1)
def test_preload_table(self, connect, table, get_index_params):
index_params = get_index_params
status, ids = connect.add_vectors(table, vectors)
status = connect.create_index(table, index_params)
status = connect.preload_table(table)
assert status.OK()
@pytest.mark.level(1)
def test_preload_table_ip(self, connect, ip_table, get_index_params):
index_params = get_index_params
status, ids = connect.add_vectors(ip_table, vectors)
status = connect.create_index(ip_table, index_params)
status = connect.preload_table(ip_table)
assert status.OK()
@pytest.mark.level(1)
def test_preload_table_not_existed(self, connect, table):
table_name = gen_unique_str("test_preload_table_not_existed")
index_params = random.choice(gen_index_params())
status, ids = connect.add_vectors(table, vectors)
status = connect.create_index(table, index_params)
status = connect.preload_table(table_name)
assert not status.OK()
@pytest.mark.level(1)
def test_preload_table_not_existed_ip(self, connect, ip_table):
table_name = gen_unique_str("test_preload_table_not_existed")
index_params = random.choice(gen_index_params())
status, ids = connect.add_vectors(ip_table, vectors)
status = connect.create_index(ip_table, index_params)
status = connect.preload_table(table_name)
assert not status.OK()
    @pytest.mark.level(1)
    def test_preload_table_no_vectors(self, connect, table):
        # Preloading an empty table should still succeed.
        status = connect.preload_table(table)
        assert status.OK()
    @pytest.mark.level(1)
    def test_preload_table_no_vectors_ip(self, connect, ip_table):
        # Preloading an empty IP-metric table should still succeed.
        status = connect.preload_table(ip_table)
        assert status.OK()
    # TODO: psutils get memory usage
    @pytest.mark.level(1)
    def test_preload_table_memory_usage(self, connect, table):
        # Placeholder: intended to compare process memory before/after preload.
        pass
class TestTableInvalid(object):
    """
    Test creating table with invalid table names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_table_names()
    )
    def get_table_name(self, request):
        yield request.param

    def test_create_table_with_invalid_tablename(self, connect, get_table_name):
        """create_table with an invalid name must return a non-OK status."""
        param = {'table_name': get_table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        assert not connect.create_table(param).OK()

    def test_create_table_with_empty_tablename(self, connect):
        """create_table with an empty name must raise client-side."""
        param = {'table_name': '',
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        with pytest.raises(Exception) as e:
            connect.create_table(param)

    def test_preload_table_with_invalid_tablename(self, connect):
        """preload_table with an empty name must raise client-side."""
        with pytest.raises(Exception) as e:
            connect.preload_table('')
class TestCreateTableDimInvalid(object):
    """
    Test creating table with invalid dimension
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_dims()
    )
    def get_dim(self, request):
        yield request.param

    @pytest.mark.timeout(5)
    def test_create_table_with_invalid_dimension(self, connect, get_dim):
        """Out-of-range positive integer dimensions come back as a non-OK
        status; any other value raises client-side."""
        dimension = get_dim
        param = {'table_name': gen_unique_str("test_create_table_with_invalid_dimension"),
                 'dimension': dimension,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        if isinstance(dimension, int) and dimension > 0:
            assert not connect.create_table(param).OK()
        else:
            with pytest.raises(Exception) as e:
                connect.create_table(param)
# TODO: max / min index file size
class TestCreateTableIndexSizeInvalid(object):
    """
    Test creating tables with invalid index_file_size
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_file_sizes()
    )
    def get_file_size(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_create_table_with_invalid_file_size(self, connect, table, get_file_size):
        """Out-of-range positive integer file sizes come back as a non-OK
        status; any other value raises client-side."""
        file_size = get_file_size
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': file_size,
                 'metric_type': MetricType.L2}
        if isinstance(file_size, int) and file_size > 0:
            assert not connect.create_table(param).OK()
        else:
            with pytest.raises(Exception) as e:
                connect.create_table(param)
class TestCreateMetricTypeInvalid(object):
    """
    Test creating tables with invalid metric_type
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_metric_types()
    )
    def get_metric_type(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_create_table_with_invalid_metric_type(self, connect, table, get_metric_type):
        # Renamed from the copy-pasted `test_create_table_with_invalid_file_size`
        # so the test report names what is actually exercised.
        """create_table with an invalid metric_type must raise client-side."""
        metric_type = get_metric_type
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': metric_type}
        with pytest.raises(Exception) as e:
            status = connect.create_table(param)
def create_table(connect, **params):
    """Create a table from params['table_name']/['dimension'] with the
    module-default index_file_size and L2 metric; return the status."""
    return connect.create_table({'table_name': params["table_name"],
                                 'dimension': params["dimension"],
                                 'index_file_size': index_file_size,
                                 'metric_type': MetricType.L2})
def search_table(connect, **params):
    """Run search_vectors with table_name/top_k/nprobe/query_vectors and
    return only the status."""
    status, _ = connect.search_vectors(params["table_name"],
                                       params["top_k"],
                                       params["nprobe"],
                                       params["query_vectors"])
    return status
def preload_table(connect, **params):
    """Preload the table named in params['table_name']; return the status."""
    name = params["table_name"]
    return connect.preload_table(name)
def has(connect, **params):
    """Return the has_table result for params['table_name']."""
    name = params["table_name"]
    return connect.has_table(name)
def show(connect, **params):
    """List all tables; return only the status (table list discarded)."""
    return connect.show_tables()[0]
def delete(connect, **params):
    """Drop the table named in params['table_name']; return the status."""
    name = params["table_name"]
    return connect.delete_table(name)
def describe(connect, **params):
    """Describe the named table; return only the status (schema discarded)."""
    return connect.describe_table(params["table_name"])[0]
def rowcount(connect, **params):
    """Count rows of the named table; return only the status (count discarded)."""
    return connect.get_table_row_count(params["table_name"])[0]
def create_index(connect, **params):
    """Build an index on the named table using params['index_params']; return the status."""
    return connect.create_index(params["table_name"], params["index_params"])
# Maps an integer "stage code" to a table operation. The codes encode the
# table lifecycle: 1 = queries valid before creation, 10 = create,
# 11-15 = operations that require an existing table, 30 = delete.
# gen_sequence()/TestTableLogic permute these codes and use the ordering
# to predict whether a call sequence should succeed.
func_map = {
    # 0:has,
    1:show,
    10:create_table,
    11:describe,
    12:rowcount,
    13:search_table,
    14:preload_table,
    15:create_index,
    30:delete
    }
def gen_sequence():
    """Yield every permutation of the func_map stage codes."""
    for perm in itertools.permutations(func_map.keys()):
        yield perm
class TestTableLogic(object):
    """Run permutations of table operations and check that exactly the
    lifecycle-consistent sequences succeed (see func_map for stage codes)."""

    @pytest.mark.parametrize("logic_seq", gen_sequence())
    @pytest.mark.level(2)
    def test_logic(self, connect, logic_seq):
        """Valid sequences must succeed end-to-end; invalid ones must fail somewhere."""
        if self.is_right(logic_seq):
            self.execute(logic_seq, connect)
        else:
            self.execute_with_error(logic_seq, connect)

    def is_right(self, seq):
        """Return True when `seq` only touches the table after create (10)
        and never after delete (30)."""
        # Fix: `seq` is a tuple (from itertools.permutations); the original
        # compared it to sorted(seq) -- a list -- which can never be equal,
        # so the ascending-order shortcut was dead code.
        if sorted(seq) == list(seq):
            return True
        not_created = True
        has_deleted = False
        for op in seq:
            if op > 10 and not_created:
                return False   # table used before creation
            elif op > 10 and has_deleted:
                return False   # table used after deletion
            elif op == 10:
                not_created = False
            elif op == 30:
                has_deleted = True
        return True

    def execute(self, logic_seq, connect):
        """Run every operation in order; each must return an OK status."""
        basic_params = self.gen_params()
        for code in logic_seq:
            # logging.getLogger().info(code)
            status = func_map[code](connect, **basic_params)
            assert status.OK()

    def execute_with_error(self, logic_seq, connect):
        """Run operations in order; at least one must return a non-OK status."""
        basic_params = self.gen_params()
        error_flag = False
        for code in logic_seq:
            status = func_map[code](connect, **basic_params)
            if not status.OK():
                # logging.getLogger().info(code)
                error_flag = True
                break
        assert error_flag == True

    def gen_params(self):
        """Build one parameter set shared by every operation in a sequence."""
        table_name = gen_unique_str("test_table")
        top_k = 1
        vectors = gen_vectors(2, dim)
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_type': IndexType.IVFLAT,
                 'metric_type': MetricType.L2,
                 'nprobe': 1,
                 'top_k': top_k,
                 'index_params': {
                     'index_type': IndexType.IVF_SQ8,
                     'nlist': 16384
                 },
                 'query_vectors': vectors}
        return param

View File

@ -0,0 +1,296 @@
import random
import pdb
import pytest
import logging
import itertools
from time import sleep
from multiprocessing import Process
from milvus import Milvus
from utils import *
from milvus import IndexType, MetricType
dim = 128
index_file_size = 10
add_time_interval = 5
class TestTableCount:
    """
    params means different nb, the nb value may trigger merge, or not
    """
    @pytest.fixture(
        scope="function",
        params=[
            100,
            5000,
            100000,
        ],
    )
    def add_vectors_nb(self, request):
        # Row counts chosen so that the add may or may not trigger a segment merge.
        yield request.param

    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    def test_table_rows_count(self, connect, table, add_vectors_nb):
        '''
        target: test table rows_count is correct or not
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nb = add_vectors_nb
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=table, records=vectors)
        time.sleep(add_time_interval)  # give the server time to flush before counting
        status, res = connect.get_table_row_count(table)
        assert res == nb

    def test_table_rows_count_after_index_created(self, connect, table, get_index_params):
        '''
        target: test get_table_row_count, after index have been created
        method: add vectors, create an index, then call get_table_row_count
        expected: the count still equals the number of vectors added
        '''
        nb = 100
        index_params = get_index_params
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=table, records=vectors)
        time.sleep(add_time_interval)
        # logging.getLogger().info(index_params)
        connect.create_index(table, index_params)
        status, res = connect.get_table_row_count(table)
        assert res == nb

    @pytest.mark.level(2)
    def test_count_without_connection(self, table, dis_connect):
        '''
        target: test get_table_row_count, without connection
        method: calling get_table_row_count with correct params, with a disconnected instance
        expected: get_table_row_count raise exception
        '''
        with pytest.raises(Exception) as e:
            status = dis_connect.get_table_row_count(table)

    def test_table_rows_count_no_vectors(self, connect, table):
        '''
        target: test table rows_count is correct or not, if table is empty
        method: create table and no vectors in it,
            assert the value returned by get_table_row_count method is equal to 0
        expected: the count is equal to 0
        '''
        table_name = gen_unique_str("test_table")
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size}
        connect.create_table(param)
        # Bug fix: count the freshly created empty table, not the `table`
        # fixture -- the old code never exercised the table it just created.
        status, res = connect.get_table_row_count(table_name)
        assert res == 0

    # TODO: enable
    @pytest.mark.level(2)
    @pytest.mark.timeout(20)
    def _test_table_rows_count_multiprocessing(self, connect, table, args):
        '''
        target: test table rows_count is correct or not with multiprocess
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        # NOTE: the leading underscore keeps pytest from collecting this test (disabled).
        nq = 2
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        vectors = gen_vectors(nq, dim)
        res = connect.add_vectors(table_name=table, records=vectors)
        time.sleep(add_time_interval)
        def rows_count(milvus):
            status, res = milvus.get_table_row_count(table)
            logging.getLogger().info(status)
            assert res == nq
        process_num = 8
        processes = []
        for i in range(process_num):
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=rows_count, args=(milvus, ))
            processes.append(p)
            p.start()
            logging.getLogger().info(p)
        for p in processes:
            p.join()

    def test_table_rows_count_multi_tables(self, connect):
        '''
        target: test table rows_count is correct or not with multiple tables of L2
        method: create 50 tables, add the same vectors to each,
            then count rows of every table
        expected: every count equals the number of vectors added
        '''
        nq = 100
        vectors = gen_vectors(nq, dim)
        table_list = []
        for i in range(50):
            table_name = gen_unique_str('test_table_rows_count_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.L2}
            connect.create_table(param)
            res = connect.add_vectors(table_name=table_name, records=vectors)
        time.sleep(2)
        for i in range(50):
            status, res = connect.get_table_row_count(table_list[i])
            assert status.OK()
            assert res == nq
class TestTableCountIP:
    """
    params means different nb, the nb value may trigger merge, or not
    """
    @pytest.fixture(
        scope="function",
        params=[
            100,
            5000,
            100000,
        ],
    )
    def add_vectors_nb(self, request):
        # Row counts chosen so that the add may or may not trigger a segment merge.
        yield request.param

    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    def test_table_rows_count(self, connect, ip_table, add_vectors_nb):
        '''
        target: test table rows_count is correct or not
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nb = add_vectors_nb
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=ip_table, records=vectors)
        time.sleep(add_time_interval)  # give the server time to flush before counting
        status, res = connect.get_table_row_count(ip_table)
        assert res == nb

    def test_table_rows_count_after_index_created(self, connect, ip_table, get_index_params):
        '''
        target: test get_table_row_count, after index have been created
        method: add vectors, create an index, then call get_table_row_count
        expected: the count still equals the number of vectors added
        '''
        nb = 100
        index_params = get_index_params
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=ip_table, records=vectors)
        time.sleep(add_time_interval)
        # logging.getLogger().info(index_params)
        connect.create_index(ip_table, index_params)
        status, res = connect.get_table_row_count(ip_table)
        assert res == nb

    @pytest.mark.level(2)
    def test_count_without_connection(self, ip_table, dis_connect):
        '''
        target: test get_table_row_count, without connection
        method: calling get_table_row_count with correct params, with a disconnected instance
        expected: get_table_row_count raise exception
        '''
        with pytest.raises(Exception) as e:
            status = dis_connect.get_table_row_count(ip_table)

    def test_table_rows_count_no_vectors(self, connect, ip_table):
        '''
        target: test table rows_count is correct or not, if table is empty
        method: create table and no vectors in it,
            assert the value returned by get_table_row_count method is equal to 0
        expected: the count is equal to 0
        '''
        table_name = gen_unique_str("test_table")
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size}
        connect.create_table(param)
        # Bug fix: count the freshly created empty table, not the ip_table
        # fixture -- the old code never exercised the table it just created.
        status, res = connect.get_table_row_count(table_name)
        assert res == 0

    # TODO: enable
    @pytest.mark.level(2)
    @pytest.mark.timeout(20)
    def _test_table_rows_count_multiprocessing(self, connect, ip_table, args):
        '''
        target: test table rows_count is correct or not with multiprocess
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        # NOTE: the leading underscore keeps pytest from collecting this test (disabled).
        nq = 2
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        vectors = gen_vectors(nq, dim)
        res = connect.add_vectors(table_name=ip_table, records=vectors)
        time.sleep(add_time_interval)
        def rows_count(milvus):
            status, res = milvus.get_table_row_count(ip_table)
            logging.getLogger().info(status)
            assert res == nq
        process_num = 8
        processes = []
        for i in range(process_num):
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=rows_count, args=(milvus,))
            processes.append(p)
            p.start()
            logging.getLogger().info(p)
        for p in processes:
            p.join()

    def test_table_rows_count_multi_tables(self, connect):
        '''
        target: test table rows_count is correct or not with multiple tables of IP
        method: create 50 tables, add the same vectors to each,
            then count rows of every table
        expected: every count equals the number of vectors added
        '''
        nq = 100
        vectors = gen_vectors(nq, dim)
        table_list = []
        for i in range(50):
            table_name = gen_unique_str('test_table_rows_count_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.IP}
            connect.create_table(param)
            res = connect.add_vectors(table_name=table_name, records=vectors)
        time.sleep(2)
        for i in range(50):
            status, res = connect.get_table_row_count(table_list[i])
            assert status.OK()
            assert res == nq

View File

@ -0,0 +1,545 @@
# STL imports
import random
import string
import struct
import sys
import time, datetime
import copy
import numpy as np
from utils import *
from milvus import Milvus, IndexType, MetricType
def gen_inaccuracy(num):
    """Scale `num` by 1/255 (presumably the 8-bit/SQ8 quantization step -- confirm)."""
    return num / 255.0
def gen_vectors(num, dim):
    """Return `num` random float vectors of length `dim`, components in [0, 1)."""
    result = []
    for _ in range(num):
        result.append([random.random() for _ in range(dim)])
    return result
def gen_single_vector(dim):
    """Return one random vector of length `dim`, wrapped in a list (batch of 1)."""
    vec = [random.random() for _ in range(dim)]
    return [vec]
def gen_vector(nb, d, seed=np.random.RandomState(1234)):
    """Return `nb` pseudo-random float32 vectors of dimension `d` as nested lists.

    NOTE(review): the default RandomState is created once at import time and
    shared across calls, so repeated no-seed calls advance the same stream
    and return different data -- confirm whether that is intended.
    """
    data = seed.rand(nb, d).astype("float32")
    return data.tolist()
def gen_unique_str(str=None):
    """Return a random 8-char alphanumeric token; with `str`, return 'str_<token>'."""
    alphabet = string.ascii_letters + string.digits
    token = "".join(random.choice(alphabet) for _ in range(8))
    return token if str is None else "%s_%s" % (str, token)
def get_current_day():
    """Return today's local date formatted as 'YYYY-MM-DD'."""
    return datetime.datetime.now().strftime('%Y-%m-%d')
def get_last_day(day):
    """Return the local date `day` days in the past as 'YYYY-MM-DD'."""
    target = datetime.datetime.now() - datetime.timedelta(days=day)
    return target.strftime('%Y-%m-%d')
def get_next_day(day):
    """Return the local date `day` days in the future as 'YYYY-MM-DD'."""
    target = datetime.datetime.now() + datetime.timedelta(days=day)
    return target.strftime('%Y-%m-%d')
def gen_long_str(num):
    """Return a random string of length `num` drawn from the letters of 'tomorrow'.

    Bug fix: the original built the string but never returned it, so every
    call silently returned None.
    """
    chars = [random.choice('tomorrow') for _ in range(num)]
    return ''.join(chars)
def gen_invalid_ips():
    """Return host values expected to be rejected when connecting:
    reserved/broadcast addresses, malformed strings, whitespace,
    non-ASCII text, and an over-long name."""
    ips = [
        "255.0.0.0",
        "255.255.0.0",
        "255.255.255.0",
        "255.255.255.255",
        "127.0.0",
        "123.0.0.2",
        "12-s",
        " ",
        "12 s",
        "BB。A",
        " siede ",
        "(mn)",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return ips
def gen_invalid_ports():
    """Return port values expected to be rejected when connecting."""
    ports = [
        # empty
        " ",
        -1,
        # too big port
        100000,
        # not correct port
        39540,
        "BB。A",
        " siede ",
        "(mn)",
        "\n",
        "\t",
        "中文"
    ]
    return ports
def gen_invalid_uris():
    """Return URI strings expected to be rejected when connecting;
    most invalid-port/protocol variants are currently commented out."""
    ip = None
    port = 19530
    uris = [
        " ",
        "中文",
        # invalid protocol
        # "tc://%s:%s" % (ip, port),
        # "tcp%s:%s" % (ip, port),
        # # invalid port
        # "tcp://%s:100000" % ip,
        # "tcp://%s: " % ip,
        # "tcp://%s:19540" % ip,
        # "tcp://%s:-1" % ip,
        # "tcp://%s:string" % ip,
        # invalid ip
        "tcp:// :%s" % port,
        "tcp://123.0.0.1:%s" % port,
        "tcp://127.0.0:%s" % port,
        "tcp://255.0.0.0:%s" % port,
        "tcp://255.255.0.0:%s" % port,
        "tcp://255.255.255.0:%s" % port,
        "tcp://255.255.255.255:%s" % port,
        "tcp://\n:%s" % port,
    ]
    return uris
def gen_invalid_table_names():
    """Return table names expected to be rejected by create_table:
    punctuation, whitespace, non-ASCII text, and an over-long name."""
    table_names = [
        "12-s",
        "12/s",
        " ",
        # "",
        # None,
        "12 s",
        "BB。A",
        "c|c",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return table_names
def gen_invalid_top_ks():
    """Return top_k values expected to be rejected by search_vectors
    (non-positive ints, wrong types, malformed strings)."""
    top_ks = [
        0,
        -1,
        None,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return top_ks
def gen_invalid_dims():
    """Return dimension values expected to be rejected by create_table
    (non-positive, out-of-range ints, wrong types, malformed strings)."""
    dims = [
        0,
        -1,
        100001,
        1000000000000001,
        None,
        False,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return dims
def gen_invalid_file_sizes():
    """Return index_file_size values expected to be rejected by create_table."""
    file_sizes = [
        0,
        -1,
        1000000000000001,
        None,
        False,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return file_sizes
def gen_invalid_index_types():
    """Return index_type values expected to be rejected by create_index
    (out-of-range ints, wrong types, malformed strings)."""
    invalid_types = [
        0,
        -1,
        100,
        1000000000000001,
        # None,
        False,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return invalid_types
def gen_invalid_nlists():
    """Return nlist values expected to be rejected by create_index."""
    nlists = [
        0,
        -1,
        1000000000000001,
        # None,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文"
    ]
    return nlists
def gen_invalid_nprobes():
    """Return nprobe values expected to be rejected by search_vectors."""
    nprobes = [
        0,
        -1,
        1000000000000001,
        None,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文"
    ]
    return nprobes
def gen_invalid_metric_types():
    """Return metric_type values expected to be rejected by create_table."""
    metric_types = [
        0,
        -1,
        1000000000000001,
        # None,
        [1,2,3],
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文"
    ]
    return metric_types
def gen_invalid_vectors():
    """Return vector payloads expected to be rejected by add_vectors/search
    (wrong element types, wrong container shapes, strings, None)."""
    invalid_vectors = [
        "1*2",
        [],
        [1],
        [1,2],
        [" "],
        ['a'],
        [None],
        None,
        (1,2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256))
    ]
    return invalid_vectors
def gen_invalid_vector_ids():
    """Return vector id values expected to be rejected (floats, None,
    values exceeding int64, malformed strings)."""
    invalid_vector_ids = [
        1.0,
        -1.0,
        None,
        # int 64
        10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
        " ",
        "",
        "String",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "=c",
        "\n",
        "中文",
    ]
    return invalid_vector_ids
def gen_invalid_query_ranges():
    """Return (start, end) date-range pairs expected to be rejected by search:
    empty ends, inverted ranges, and non-date start values."""
    query_ranges = [
        [(get_last_day(1), "")],
        [(get_current_day(), "")],
        [(get_next_day(1), "")],
        [(get_current_day(), get_last_day(1))],
        [(get_next_day(1), get_last_day(1))],
        [(get_next_day(1), get_current_day())],
        [(0, get_next_day(1))],
        [(-1, get_next_day(1))],
        [(1, get_next_day(1))],
        [(100001, get_next_day(1))],
        [(1000000000000001, get_next_day(1))],
        [(None, get_next_day(1))],
        [([1,2,3], get_next_day(1))],
        [((1,2), get_next_day(1))],
        [({"a": 1}, get_next_day(1))],
        [(" ", get_next_day(1))],
        [("", get_next_day(1))],
        [("String", get_next_day(1))],
        [("12-s", get_next_day(1))],
        [("BB。A", get_next_day(1))],
        [(" siede ", get_next_day(1))],
        [("(mn)", get_next_day(1))],
        [("#12s", get_next_day(1))],
        [("pip+", get_next_day(1))],
        [("=c", get_next_day(1))],
        [("\n", get_next_day(1))],
        [("\t", get_next_day(1))],
        [("中文", get_next_day(1))],
        [("a".join("a" for i in range(256)), get_next_day(1))]
    ]
    return query_ranges
def gen_invalid_index_params():
    """Return index-param dicts that are invalid in either index_type or nlist."""
    params = [{"index_type": t, "nlist": 16384} for t in gen_invalid_index_types()]
    params += [{"index_type": IndexType.IVFLAT, "nlist": n} for n in gen_invalid_nlists()]
    return params
def gen_index_params():
    """Return every valid (index_type, nlist) combination for create_index."""
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    nlists = [1, 16384, 50000]
    # Outer loop over index types, inner over nlists -- same ordering as before.
    return [{"index_type": t, "nlist": n} for t in index_types for n in nlists]
def gen_simple_index_params():
    """Return one (index_type, nlist=16384) dict per index type for quick tests.

    Returns:
        list[dict]: dicts of the form {"index_type": IndexType.*, "nlist": 16384},
        in the same index-type order as gen_index_params().
    """
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    nlists = [16384]
    # NOTE(review): dropped the original's unused `index_params = []` local and
    # the one-shot nested helper; a direct product comprehension is equivalent.
    return [{"index_type": index_type, "nlist": nlist}
            for index_type in index_types
            for nlist in nlists]
if __name__ == "__main__":
    # Ad-hoc manual driver: exercises table create / insert / index / search
    # against a running Milvus server.  Not executed by the automated suite.
    import numpy as np  # BUG FIX: was `import numpy`, but the code below uses the `np` alias
    dim = 128
    nq = 10000
    table = "test"
    file_name = '/poc/yuncong/ann_1000m/query.npy'
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    # print(vectors)
    connect = Milvus()
    # connect.connect(host="192.168.1.27")
    # print(connect.show_tables())
    # print(connect.get_table_row_count(table))
    # sys.exit()
    connect.connect(host="127.0.0.1")
    connect.delete_table(table)
    # sys.exit()
    # time.sleep(2)
    print(connect.get_table_row_count(table))
    param = {'table_name': table,
             'dimension': dim,
             'metric_type': MetricType.L2,
             'index_file_size': 10}
    status = connect.create_table(param)
    print(status)
    print(connect.get_table_row_count(table))
    # add vectors: ten batches of nq vectors each
    for i in range(10):
        status, ids = connect.add_vectors(table, vectors)
        print(status)
        print(ids[0])
    # print(ids[0])
    index_params = {"index_type": IndexType.IVFLAT, "nlist": 16384}
    status = connect.create_index(table, index_params)
    print(status)
    # sys.exit()
    query_vec = [vectors[0]]
    # print(numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0])))
    top_k = 12
    nprobe = 1
    for i in range(2):
        result = connect.search_vectors(table, top_k, nprobe, query_vec)
        print(result)
    sys.exit()
    # NOTE(review): everything below is unreachable (sys.exit() above) and
    # references names (gen_unique_str, args, index_file_size, Process,
    # gen_single_vector) not defined in this scope — looks like a pasted test
    # body kept for reference; confirm before deleting.
    table = gen_unique_str("test_add_vector_with_multiprocessing")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table,
             'dimension': dim,
             'index_file_size': index_file_size}
    # create table
    milvus = Milvus()
    milvus.connect(uri=uri)
    milvus.create_table(param)
    vector = gen_single_vector(dim)
    process_num = 4
    loop_num = 10
    processes = []
    # each worker uses its own dedicated connection
    def add(milvus):
        i = 0
        while i < loop_num:
            status, ids = milvus.add_vectors(table, vector)
            i = i + 1
    for i in range(process_num):
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=add, args=(milvus,))
        processes.append(p)
        p.start()
        time.sleep(0.2)
    for p in processes:
        p.join()
    time.sleep(3)
    status, count = milvus.get_table_row_count(table)
    assert count == process_num * loop_num