mirror of https://github.com/milvus-io/milvus.git
Merge remote-tracking branch 'source/0.5.0' into branch-0.5.0
Former-commit-id: 1383a784ce89252082a752ac91cfb6242428cbdapull/191/head
commit
fb8c3b0753
|
@ -26,6 +26,7 @@ Please mark all change in change log and use the ticket from JIRA.
|
|||
- MS-653 - When config check fail, Milvus close without message
|
||||
- MS-654 - Describe index timeout when building index
|
||||
- MS-658 - Fix SQ8 Hybrid can't search
|
||||
- \#9 Change default gpu_cache_capacity to 4
|
||||
- MS-665 - IVF_SQ8H search crash when no GPU resource in search_resources
|
||||
- \#20 - C++ sdk example get grpc error
|
||||
- \#23 - Add unittest to improve code coverage
|
||||
|
@ -75,6 +76,7 @@ Please mark all change in change log and use the ticket from JIRA.
|
|||
- MS-624 - Re-organize project directory for open-source
|
||||
- MS-635 - Add compile option to support customized faiss
|
||||
- MS-660 - add ubuntu_build_deps.sh
|
||||
- \#18 - Add all test cases
|
||||
|
||||
# Milvus 0.4.0 (2019-09-12)
|
||||
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
target/
|
||||
.idea/
|
||||
test-output/
|
||||
lib/*
|
|
@ -0,0 +1,10 @@
|
|||
// Publish files from the Jenkins workspace to a remote host via the
// publish-over-ftp plugin.
//   sourceFiles     - ant-style pattern of workspace files to publish
//   remoteDirectory - destination directory on the remote host
//   remoteIP        - name of the publish-over-ftp server config entry
//   protocol        - only "ftp" is implemented; any other value is a silent no-op
//   makeEmptyDirs   - forwarded to the publisher (create empty dirs remotely)
def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) {
    if (protocol == "ftp") {
        // failOnError: true makes a failed upload fail the build.
        ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, failOnError: true, publishers: [
            [configName: "${remoteIP}", transfers: [
                [asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true
            ]
        ]
    }
}
// Return the script object so pipelines that `load` this file can call FileTransfer().
return this
|
|
@ -0,0 +1,13 @@
|
|||
// Tear down the Helm release created for this CI run.
// `helm status <release>` exits non-zero when the release does not exist,
// so a zero (falsy) returnStatus means the release is present and must be purged.
try {
    def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true
    if (!result) {
        sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
    }
} catch (exc) {
    // Best-effort second attempt at the same check/delete so the release is
    // not leaked if the first attempt failed part-way, then re-raise so the
    // build still reports the original failure.
    def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true
    if (!result) {
        sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
    }
    throw exc
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
// Deploy a Milvus server for this CI run via Helm.
try {
    // Use the Aliyun chart mirror for the stable repo (CI runs behind the CN network).
    sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts'
    sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus'
    sh 'helm repo update'
    dir ("milvus-helm") {
        // Fetch only the branch under test from the internal GitLab mirror.
        checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]])
        dir ("milvus/milvus-gpu") {
            // --wait blocks until the release is ready (or times out after 300s);
            // the release name is unique per job+build so parallel runs don't collide.
            sh "helm install --wait --timeout 300 --set engine.image.tag=${IMAGE_TAG} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-sdk-test --version 0.3.1 ."
        }
    }
} catch (exc) {
    echo 'Helm running failed!'
    // Purge the half-installed release before propagating the failure.
    sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
    throw exc
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
// Build and run the Java SDK integration tests against the freshly deployed server.
timeout(time: 30, unit: 'MINUTES') {
    try {
        dir ("milvus-java-test") {
            sh "mvn clean install"
            // The built jar plus the dependencies copied into lib/ form the
            // classpath; -h points the tests at the in-cluster Milvus service
            // DNS name created by the deploy stage.
            sh "java -cp \"target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*\" com.MainClass -h ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-sdk-test.svc.cluster.local"
        }

    } catch (exc) {
        echo 'Milvus-SDK-Java Integration Test Failed !'
        throw exc
    }
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
// Email notification helper, intended to be `load`ed and called from a
// pipeline's post section.
def notify() {
    if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
        // Send an email only if the build status has changed from green/unstable to red
        emailext subject: '$DEFAULT_SUBJECT',
        body: '$DEFAULT_CONTENT',
        recipientProviders: [
            [$class: 'DevelopersRecipientProvider'],
            [$class: 'RequesterRecipientProvider']
        ],
        replyTo: '$DEFAULT_REPLYTO',
        to: '$DEFAULT_RECIPIENTS'
    }
}
// Return the script object so callers of `load` can invoke notify().
return this
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
// Publish the test output directory to the NAS FTP server and, when the
// build is green, print a browsable ftp:// URL. Fails the build when the
// expected test_out directory is missing.
timeout(time: 5, unit: 'MINUTES') {
    dir ("${PROJECT_NAME}_test") {
        if (fileExists('test_out')) {
            def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy"
            // 'nas storage' is the publish-over-ftp server config name.
            fileTransfer.FileTransfer("test_out/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage')
            if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
                echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\""
            }
        } else {
            // Fixed grammar of the user-facing failure message
            // (was: "directory don't exists!").
            error("Milvus Dev Test Out directory doesn't exist!")
        }
    }
}
|
|
@ -0,0 +1,110 @@
|
|||
// Declarative CI pipeline for the Milvus Java SDK release tests.
// Spins up a Kubernetes agent pod, deploys a Milvus server with Helm,
// runs the integration tests, and always purges the Helm release afterwards.
pipeline {
    agent none

    options {
        timestamps()
    }

    environment {
        SRC_BRANCH = "master"
        IMAGE_TAG = "${params.IMAGE_TAG}-release"
        HELM_BRANCH = "${params.IMAGE_TAG}"
        TEST_URL = "git@192.168.1.105:Test/milvus-java-test.git"
        TEST_BRANCH = "${params.IMAGE_TAG}"
    }

    stages {
        stage("Setup env") {
            agent {
                kubernetes {
                    label 'dev-test'
                    defaultContainer 'jnlp'
                    // Inline pod template for the test-framework container.
                    // (fixed label key typo: "componet" -> "component")
                    yaml """
apiVersion: v1
kind: Pod
metadata:
  labels:
    app: milvus
    component: test
spec:
  containers:
  - name: milvus-testframework-java
    image: registry.zilliz.com/milvus/milvus-java-test:v0.1
    command:
    - cat
    tty: true
    volumeMounts:
    - name: kubeconf
      mountPath: /root/.kube/
      readOnly: true
  volumes:
  - name: kubeconf
    secret:
      secretName: test-cluster-config
"""
                }
            }

            stages {
                stage("Deploy Server") {
                    steps {
                        // Fixed typo in the commit-status label ("Deloy" -> "Deploy").
                        gitlabCommitStatus(name: 'Deploy Server') {
                            container('milvus-testframework-java') {
                                script {
                                    load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/deploy_server.groovy"
                                }
                            }
                        }
                    }
                }
                stage("Integration Test") {
                    steps {
                        gitlabCommitStatus(name: 'Integration Test') {
                            container('milvus-testframework-java') {
                                script {
                                    print "In integration test stage"
                                    load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/integration_test.groovy"
                                }
                            }
                        }
                    }
                }
                stage ("Cleanup Env") {
                    steps {
                        gitlabCommitStatus(name: 'Cleanup Env') {
                            container('milvus-testframework-java') {
                                script {
                                    load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy"
                                }
                            }
                        }
                    }
                }
            }
            post {
                always {
                    // Cleanup also runs here so the Helm release is purged even
                    // when a stage fails or the build is aborted.
                    container('milvus-testframework-java') {
                        script {
                            load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy"
                        }
                    }
                }
                success {
                    script {
                        echo "Milvus java-sdk test success !"
                    }
                }
                aborted {
                    script {
                        echo "Milvus java-sdk test aborted !"
                    }
                }
                failure {
                    script {
                        echo "Milvus java-sdk test failed !"
                    }
                }
            }
        }
    }
}
|
|
@ -0,0 +1,13 @@
|
|||
# Pod template for the Java test-framework container (Maven + JDK 8).
# Fixed label key typo: "componet" -> "component".
apiVersion: v1
kind: Pod
metadata:
  labels:
    app: milvus
    component: testframework-java
spec:
  containers:
  - name: milvus-testframework-java
    image: maven:3.6.2-jdk-8
    # `cat` with a TTY keeps the container alive so Jenkins can exec into it.
    command:
    - cat
    tty: true
|
@ -0,0 +1,2 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
<!-- IntelliJ IDEA module descriptor: marks this directory as a plain Java module. -->
<module type="JAVA_MODULE" version="4" />
|
|
@ -0,0 +1,137 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the Milvus Java SDK integration-test suite.
     `mvn package` copies every runtime dependency into lib/ (see the
     maven-dependency-plugin execution below) so the tests can be launched with
     `java -cp "target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*" com.MainClass`. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>milvus</groupId>
    <artifactId>MilvusSDkJavaTest</artifactId>
    <version>1.0-SNAPSHOT</version>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy-dependencies</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <!-- lib/ is the directory the CI launch command puts on the classpath. -->
                            <outputDirectory>lib</outputDirectory>
                            <overWriteReleases>false</overWriteReleases>
                            <overWriteSnapshots>true</overWriteSnapshots>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <grpc.version>1.23.0</grpc.version><!-- CURRENT_GRPC_VERSION -->
        <protobuf.version>3.9.0</protobuf.version>
        <protoc.version>3.9.0</protoc.version>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

<!--    <dependencyManagement>-->
<!--        <dependencies>-->
<!--            <dependency>-->
<!--                <groupId>io.grpc</groupId>-->
<!--                <artifactId>grpc-bom</artifactId>-->
<!--                <version>${grpc.version}</version>-->
<!--                <type>pom</type>-->
<!--                <scope>import</scope>-->
<!--            </dependency>-->
<!--        </dependencies>-->
<!--    </dependencyManagement>-->

    <repositories>
        <repository>
            <id>oss.sonatype.org-snapshot</id>
            <!-- NOTE(review): plain-http repository URL; newer Maven versions
                 block http mirrors — consider https. -->
            <url>http://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
            </snapshots>
        </repository>
    </repositories>

    <dependencies>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.4</version>
        </dependency>

        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.3</version>
        </dependency>

        <dependency>
            <groupId>org.testng</groupId>
            <artifactId>testng</artifactId>
            <version>6.10</version>
        </dependency>

        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.9</version>
        </dependency>

<!--        <dependency>-->
<!--            <groupId>io.milvus</groupId>-->
<!--            <artifactId>milvus-sdk-java</artifactId>-->
<!--            <version>0.1.0</version>-->
<!--        </dependency>-->

        <!-- SNAPSHOT of the SDK under test, resolved from the sonatype
             snapshot repository declared above. -->
        <dependency>
            <groupId>io.milvus</groupId>
            <artifactId>milvus-sdk-java</artifactId>
            <version>0.1.1-SNAPSHOT</version>
        </dependency>

<!--        <dependency>-->
<!--            <groupId>io.grpc</groupId>-->
<!--            <artifactId>grpc-netty-shaded</artifactId>-->
<!--            <scope>runtime</scope>-->
<!--        </dependency>-->
<!--        <dependency>-->
<!--            <groupId>io.grpc</groupId>-->
<!--            <artifactId>grpc-protobuf</artifactId>-->
<!--        </dependency>-->
<!--        <dependency>-->
<!--            <groupId>io.grpc</groupId>-->
<!--            <artifactId>grpc-stub</artifactId>-->
<!--        </dependency>-->
<!--        <dependency>-->
<!--            <groupId>javax.annotation</groupId>-->
<!--            <artifactId>javax.annotation-api</artifactId>-->
<!--            <version>1.2</version>-->
<!--            <scope>provided</scope> <!– not needed at runtime –>-->
<!--        </dependency>-->
<!--        <dependency>-->
<!--            <groupId>io.grpc</groupId>-->
<!--            <artifactId>grpc-testing</artifactId>-->
<!--            <scope>test</scope>-->
<!--        </dependency>-->
<!--        <dependency>-->
<!--            <groupId>com.google.protobuf</groupId>-->
<!--            <artifactId>protobuf-java-util</artifactId>-->
<!--            <version>${protobuf.version}</version>-->
<!--        </dependency>-->

    </dependencies>

</project>
|
|
@ -0,0 +1,147 @@
|
|||
package com;

import io.milvus.client.*;
import org.apache.commons.cli.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.SkipException;
import org.testng.TestNG;
import org.testng.annotations.DataProvider;
import org.testng.xml.XmlClass;
import org.testng.xml.XmlSuite;
import org.testng.xml.XmlTest;

import java.util.ArrayList;
import java.util.List;

/**
 * Entry point and shared TestNG data providers for the Milvus Java SDK
 * test suite. Parses -h/-p command-line options for the server address,
 * then programmatically assembles and runs one XmlSuite containing all
 * test classes.
 */
public class MainClass {
    // Server address; overridden by the -h / -p command-line options.
    private static String host = "127.0.0.1";
    private static String port = "19530";
    // Defaults used when creating tables in the "Table" data provider.
    public Integer index_file_size = 50;
    public Integer dimension = 128;

    public static void setHost(String host) {
        MainClass.host = host;
    }

    public static void setPort(String port) {
        MainClass.port = port;
    }

    /** Supplies the raw (host, port) pair without opening a connection. */
    @DataProvider(name="DefaultConnectArgs")
    public static Object[][] defaultConnectArgs(){
        return new Object[][]{{host, port}};
    }

    /** Supplies a connected client plus a fresh random table name. */
    @DataProvider(name="ConnectInstance")
    public Object[][] connectInstance(){
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        client.connect(connectParam);
        String tableName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, tableName}};
    }

    /** Supplies a client that has been connected and then disconnected,
     *  for negative tests that require a dead connection. */
    @DataProvider(name="DisConnectInstance")
    public Object[][] disConnectInstance(){
        // Generate connection instance
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        client.connect(connectParam);
        try {
            client.disconnect();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        String tableName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, tableName}};
    }

    /** Supplies one pre-created table per metric type (L2 and IP), each with
     *  its own connected client; skips the test if table creation fails. */
    @DataProvider(name="Table")
    public Object[][] provideTable(){
        Object[][] tables = new Object[2][2];
        MetricType metricTypes[] = { MetricType.L2, MetricType.IP };
        for (Integer i = 0; i < metricTypes.length; ++i) {
            String tableName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
            // Generate connection instance
            MilvusClient client = new MilvusGrpcClient();
            ConnectParam connectParam = new ConnectParam.Builder()
                    .withHost(host)
                    .withPort(port)
                    .build();
            client.connect(connectParam);
            TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                    .withIndexFileSize(index_file_size)
                    .withMetricType(metricTypes[i])
                    .build();
            TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
            Response res = client.createTable(tableSchemaParam);
            if (!res.ok()) {
                System.out.println(res.getMessage());
                throw new SkipException("Table created failed");
            }
            tables[i] = new Object[]{client, tableName};
        }
        return tables;
    }

    /**
     * Parses -h/--host and -p/--port, stores them in the static fields used
     * by the data providers, then runs all test classes in one TestNG suite.
     */
    public static void main(String[] args) {
        CommandLineParser parser = new DefaultParser();
        Options options = new Options();
        options.addOption("h", "host", true, "milvus-server hostname/ip");
        options.addOption("p", "port", true, "milvus-server port");
        try {
            CommandLine cmd = parser.parse(options, args);
            String host = cmd.getOptionValue("host");
            if (host != null) {
                setHost(host);
            }
            String port = cmd.getOptionValue("port");
            if (port != null) {
                setPort(port);
            }
            // NOTE(review): prints the raw option values, which are null when
            // the option was not supplied (the defaults are in the static
            // fields) — confirm whether the effective values were intended.
            System.out.println("Host: "+host+", Port: "+port);
        }
        catch(ParseException exp) {
            System.err.println("Parsing failed. Reason: " + exp.getMessage() );
        }

//        TestListenerAdapter tla = new TestListenerAdapter();
//        TestNG testng = new TestNG();
//        testng.setTestClasses(new Class[] { TestPing.class });
//        testng.setTestClasses(new Class[] { TestConnect.class });
//        testng.addListener(tla);
//        testng.run();

        // Build the suite programmatically instead of using a testng.xml file.
        XmlSuite suite = new XmlSuite();
        suite.setName("TmpSuite");

        XmlTest test = new XmlTest(suite);
        test.setName("TmpTest");
        List<XmlClass> classes = new ArrayList<XmlClass>();

        classes.add(new XmlClass("com.TestPing"));
        classes.add(new XmlClass("com.TestAddVectors"));
        classes.add(new XmlClass("com.TestConnect"));
        classes.add(new XmlClass("com.TestDeleteVectors"));
        classes.add(new XmlClass("com.TestIndex"));
        classes.add(new XmlClass("com.TestSearchVectors"));
        classes.add(new XmlClass("com.TestTable"));
        classes.add(new XmlClass("com.TestTableCount"));

        test.setXmlClasses(classes) ;

        List<XmlSuite> suites = new ArrayList<XmlSuite>();
        suites.add(suite);
        TestNG tng = new TestNG();
        tng.setXmlSuites(suites);
        tng.run();

    }

}
|
|
@ -0,0 +1,154 @@
|
|||
package com;

import io.milvus.client.InsertParam;
import io.milvus.client.InsertResponse;
import io.milvus.client.MilvusClient;
import io.milvus.client.TableParam;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * TestNG tests for vector insertion: happy path, missing table, dead
 * connection, explicit/invalid ids, invalid dimensions, timeouts, bulk
 * loads and repeated inserts. Fix applied: the original called the static
 * Thread.sleep through Thread.currentThread().sleep(...), a misleading
 * static-via-instance call; replaced with the direct static call.
 */
public class TestAddVectors {
    int dimension = 128;

    /** Generate nb random vectors of `dimension` floats each. */
    public List<List<Float>> gen_vectors(Integer nb) {
        List<List<Float>> xb = new LinkedList<>();
        Random random = new Random();
        for (int i = 0; i < nb; ++i) {
            LinkedList<Float> vector = new LinkedList<>();
            for (int j = 0; j < dimension; j++) {
                vector.add(random.nextFloat());
            }
            xb.add(vector);
        }
        return xb;
    }

    /** Inserting into a non-existent table must fail. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        String tableNameNew = tableName + "_";
        InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    /** Inserting on a disconnected client must fail. */
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_add_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    /** Happy path: insert succeeds and the row count reflects it. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        // Give the server time to flush before counting rows.
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
    }

    /** A 1-unit timeout on a large insert must produce a failed response. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_timeout(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 200000;
        List<List<Float>> vectors = gen_vectors(nb);
        System.out.println(new Date());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(1).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    /** Bulk insert of 500k vectors succeeds without a timeout. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_big_data(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 500000;
        List<List<Float>> vectors = gen_vectors(nb);
        System.out.println(new Date());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
    }

    /** Insert with caller-supplied vector ids and verify the row count. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_ids(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        // Add vectors with ids
        List<Long> vectorIds;
        vectorIds = Stream.iterate(0L, n -> n)
                .limit(nb)
                .collect(Collectors.toList());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        // Give the server time to flush before counting rows.
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
    }

    // TODO: MS-628
    /** An id list longer than the vector list must be rejected. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_invalid_ids(MilvusClient client, String tableName) {
        int nb = 10;
        List<List<Float>> vectors = gen_vectors(nb);
        // Add vectors with ids
        List<Long> vectorIds;
        vectorIds = Stream.iterate(0L, n -> n)
                .limit(nb+1)
                .collect(Collectors.toList());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    /** One vector with an extra dimension must make the whole insert fail. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_invalid_dimension(MilvusClient client, String tableName) {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        vectors.get(0).add((float) 0);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    /** An empty vector in the batch must make the insert fail. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_with_invalid_vectors(MilvusClient client, String tableName) {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        vectors.set(0, new ArrayList<>());
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(!res.getResponse().ok());
    }

    /** Repeated inserts accumulate: final count is nb * loops. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_repeatably(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100000;
        int loops = 10;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = null;
        for (int i = 0; i < loops; ++i ) {
            long startTime = System.currentTimeMillis();
            res = client.insert(insertParam);
            long endTime = System.currentTimeMillis();
            System.out.println("Total execution time: " + (endTime-startTime) + "ms");
        }
        // Give the server time to flush before counting rows.
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb * loops);
    }

}
|
|
@ -0,0 +1,80 @@
|
|||
package com;

import io.milvus.client.ConnectParam;
import io.milvus.client.MilvusClient;
import io.milvus.client.MilvusGrpcClient;
import io.milvus.client.Response;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * TestNG tests for connect/disconnect behaviour of the Milvus gRPC client:
 * first connect, repeated connect, invalid addresses, and disconnect
 * semantics.
 */
public class TestConnect {
    /** Connecting to a reachable server succeeds and marks the client connected. */
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_connect(String host, String port){
        System.out.println("Host: "+host+", Port: "+port);
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        Response res = client.connect(connectParam);
        assert(res.ok());
        assert(client.connected());
    }

    /** A second connect on an already-connected client returns a failed
     *  response but the client stays connected. */
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_connect_repeat(String host, String port){
        MilvusGrpcClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        client.connect(connectParam);
        Response res = client.connect(connectParam);
        assert(!res.ok());
        assert(client.connected());
    }

    /** None of the invalid host/port combinations may yield a connection. */
    @Test(dataProvider="InvalidConnectArgs")
    public void test_connect_invalid_connect_args(String ip, String port) throws InterruptedException {
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(ip)
                .withPort(port)
                .build();
        client.connect(connectParam);
        assert(!client.connected());
    }

    // TODO: MS-615
    /** Unreachable IPs, malformed hosts, and out-of-range ports. */
    @DataProvider(name="InvalidConnectArgs")
    public Object[][] generate_invalid_connect_args() {
        String port = "19530";
        String ip = "";
        return new Object[][]{
                {"1.1.1.1", port},
                {"255.255.0.0", port},
                {"1.2.2", port},
                {"中文", port},
                {"www.baidu.com", "100000"},
                {"127.0.0.1", "100000"},
                {"www.baidu.com", "80"},
        };
    }

    /** The DisConnectInstance provider already disconnected the client. */
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_disconnect(MilvusClient client, String tableName){
        assert(!client.connected());
    }

    /** A second disconnect still reports success and the client stays
     *  disconnected.
     *  NOTE(review): if disconnect() throws InterruptedException, `res`
     *  stays null and res.ok() raises NPE instead of a clean assert —
     *  confirm whether that is intended. */
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_disconnect_repeatably(MilvusClient client, String tableNam){
        Response res = null;
        try {
            res = client.disconnect();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        assert(res.ok());
        assert(!client.connected());
    }
}
|
|
@ -0,0 +1,122 @@
|
|||
package com;

import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.*;

/**
 * TestNG tests for delete-by-date-range: full delete, missing table, dead
 * connection, empty table, inverted/invalid ranges, and a range that
 * matches nothing.
 */
public class TestDeleteVectors {
    // NOTE(review): index_file_size is declared but unused in this class.
    int index_file_size = 50;
    int dimension = 128;

    /** Generate nb random vectors of `dimension` floats each. */
    public List<List<Float>> gen_vectors(Integer nb) {
        List<List<Float>> xb = new LinkedList<>();
        Random random = new Random();
        for (int i = 0; i < nb; ++i) {
            LinkedList<Float> vector = new LinkedList<>();
            for (int j = 0; j < dimension; j++) {
                vector.add(random.nextFloat());
            }
            xb.add(vector);
        }
        return xb;
    }

    /** Return today's date shifted by `delta` days (negative = past). */
    public static Date getDeltaDate(int delta) {
        Date today = new Date();
        Calendar c = Calendar.getInstance();
        c.setTime(today);
        c.add(Calendar.DAY_OF_MONTH, delta);
        return c.getTime();
    }

    /** Deleting over [yesterday, tomorrow] removes every row inserted today. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 10000;
        List<List<Float>> vectors = gen_vectors(nb);
        // Add vectors
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        Thread.sleep(1000);
        DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
        Response res_delete = client.deleteByRange(param);
        assert(res_delete.ok());
        Thread.sleep(1000);
        // Assert table row count
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0);
    }

    /** Deleting from a non-existent table must fail. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
        String tableNameNew = tableName + "_";
        DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableNameNew).build();
        Response res_delete = client.deleteByRange(param);
        assert(!res_delete.ok());
    }

    /** Deleting on a disconnected client must fail. */
    @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
    public void test_delete_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
        DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
        Response res_delete = client.deleteByRange(param);
        assert(!res_delete.ok());
    }

    /** Deleting from an empty table is a successful no-op. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_table_empty(MilvusClient client, String tableName) throws InterruptedException {
        DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
        Response res_delete = client.deleteByRange(param);
        assert(res_delete.ok());
    }

    /** A range whose start is after its end (tomorrow -> today) must fail. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_invalid_date_range(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        List<List<Float>> vectors = gen_vectors(nb);
        // Add vectors
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        Thread.sleep(1000);
        DateRange dateRange = new DateRange(getDeltaDate(1), getDeltaDate(0));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
        Response res_delete = client.deleteByRange(param);
        assert(!res_delete.ok());
    }

    /** A wider inverted range (+2 days -> -1 day) must also fail. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_invalid_date_range_1(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        DateRange dateRange = new DateRange(getDeltaDate(2), getDeltaDate(-1));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
        Response res_delete = client.deleteByRange(param);
        assert(!res_delete.ok());
    }

    /** A valid range entirely in the past deletes nothing: count unchanged. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_delete_vectors_no_result(MilvusClient client, String tableName) throws InterruptedException {
        int nb = 100;
        List<List<Float>> vectors = gen_vectors(nb);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        InsertResponse res = client.insert(insertParam);
        assert(res.getResponse().ok());
        Thread.sleep(1000);
        DateRange dateRange = new DateRange(getDeltaDate(-3), getDeltaDate(-2));
        DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
        Response res_delete = client.deleteByRange(param);
        assert(res_delete.ok());
        TableParam tableParam = new TableParam.Builder(tableName).build();
        Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
    }

}
|
|
@ -0,0 +1,340 @@
|
|||
package com;
|
||||
|
||||
import io.milvus.client.*;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public class TestIndex {
|
||||
int index_file_size = 10;
|
||||
int dimension = 128;
|
||||
int n_list = 1024;
|
||||
int default_n_list = 16384;
|
||||
int nb = 100000;
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
IndexType defaultIndexType = IndexType.FLAT;
|
||||
|
||||
public List<List<Float>> gen_vectors(Integer nb) {
|
||||
List<List<Float>> xb = new LinkedList<>();
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
LinkedList<Float> vector = new LinkedList<>();
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
vector.add(random.nextFloat());
|
||||
}
|
||||
xb.add(vector);
|
||||
}
|
||||
return xb;
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_repeatably(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getNList(), n_list);
|
||||
Assert.assertEquals(index1.getIndexType(), indexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_FLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.FLAT;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getIndexType(), indexType);
|
||||
}
|
||||
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int nb = 500000;
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
System.out.println(new Date());
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).withTimeout(1).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(!res_create.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVFLAT;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getIndexType(), indexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getIndexType(), indexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_IVFSQ8H(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8_H;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getIndexType(), indexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_with_no_vector(MilvusClient client, String tableName) {
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
|
||||
String tableNameNew = tableName + "_";
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableNameNew).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(!res_create.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(!res_create.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_create_index_invalid_n_list(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int n_list = 0;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(!res_create.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_describe_index(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getNList(), n_list);
|
||||
Assert.assertEquals(index1.getIndexType(), indexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_alter_index(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
// Create another index
|
||||
IndexType indexTypeNew = IndexType.IVFLAT;
|
||||
int n_list_new = n_list + 1;
|
||||
Index index_new = new Index.Builder().withIndexType(indexTypeNew)
|
||||
.withNList(n_list_new)
|
||||
.build();
|
||||
CreateIndexParam createIndexParamNew = new CreateIndexParam.Builder(tableName).withIndex(index_new).build();
|
||||
Response res_create_new = client.createIndex(createIndexParamNew);
|
||||
assert(res_create_new.ok());
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res_create.ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getNList(), n_list_new);
|
||||
Assert.assertEquals(index1.getIndexType(), indexTypeNew);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_describe_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
|
||||
String tableNameNew = tableName + "_";
|
||||
TableParam tableParam = new TableParam.Builder(tableNameNew).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_describe_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_drop_index(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(defaultIndexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Response res_drop = client.dropIndex(tableParam);
|
||||
assert(res_drop.ok());
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getNList(), default_n_list);
|
||||
Assert.assertEquals(index1.getIndexType(), defaultIndexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_drop_index_repeatably(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(defaultIndexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
Response res_create = client.createIndex(createIndexParam);
|
||||
assert(res_create.ok());
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Response res_drop = client.dropIndex(tableParam);
|
||||
res_drop = client.dropIndex(tableParam);
|
||||
assert(res_drop.ok());
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getNList(), default_n_list);
|
||||
Assert.assertEquals(index1.getIndexType(), defaultIndexType);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_drop_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
|
||||
String tableNameNew = tableName + "_";
|
||||
TableParam tableParam = new TableParam.Builder(tableNameNew).build();
|
||||
Response res_drop = client.dropIndex(tableParam);
|
||||
assert(!res_drop.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_drop_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Response res_drop = client.dropIndex(tableParam);
|
||||
assert(!res_drop.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_drop_index_no_index_created(MilvusClient client, String tableName) throws InterruptedException {
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Response res_drop = client.dropIndex(tableParam);
|
||||
assert(res_drop.ok());
|
||||
DescribeIndexResponse res = client.describeIndex(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Index index1 = res.getIndex().get();
|
||||
Assert.assertEquals(index1.getNList(), default_n_list);
|
||||
Assert.assertEquals(index1.getIndexType(), defaultIndexType);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
package com;
|
||||
|
||||
import io.milvus.client.*;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.ForkJoinPool;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
 * Concurrency tests: each case fans work out over a ForkJoinPool and then
 * waits for quiescence before asserting on the aggregate server state.
 * Tables / connection arguments are supplied by MainClass data providers.
 */
public class TestMix {
    private int dimension = 128;   // dimension of every generated vector
    private int nb = 100000;       // vectors inserted per task
    int n_list = 8192;             // nlist used when building IVF_SQ8 indexes
    int n_probe = 20;              // nprobe used for searches
    int top_k = 10;                // hits requested per query
    double epsilon = 0.001;        // tolerance for distance comparisons
    int index_file_size = 20;      // segment file size for tables created in-test

    /**
     * Return a copy of {@code w2v} scaled to unit Euclidean length.
     */
    public List<Float> normalize(List<Float> w2v){
        float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
        final float norm = (float) Math.sqrt(squareSum);
        w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
        return w2v;
    }

    /**
     * Generate {@code nb} random vectors; normalize each one when {@code norm} is true.
     */
    public List<List<Float>> gen_vectors(int nb, boolean norm) {
        List<List<Float>> xb = new ArrayList<>();
        Random random = new Random();
        for (int i = 0; i < nb; ++i) {
            List<Float> vector = new ArrayList<>();
            for (int j = 0; j < dimension; j++) {
                vector.add(random.nextFloat());
            }
            if (norm == true) {
                vector = normalize(vector);
            }
            xb.add(vector);
        }
        return xb;
    }

    // Many threads searching the same indexed table concurrently; every
    // search must report success.
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 10;
        int nq = 5;
        List<List<Float>> vectors = gen_vectors(nb, false);
        List<List<Float>> queryVectors = vectors.subList(0,nq);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        client.insert(insertParam);
        Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
                .withNList(n_list)
                .build();
        CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
        client.createIndex(createIndexParam);
        ForkJoinPool executor = new ForkJoinPool();
        for (int i = 0; i < thread_num; i++) {
            executor.execute(
                    () -> {
                        SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
                        SearchResponse res_search = client.search(searchParam);
                        assert (res_search.getResponse().ok());
                    });
        }
        // Wait (bounded) for all submitted tasks to finish before tearing down.
        executor.awaitQuiescence(100, TimeUnit.SECONDS);
        executor.shutdown();
    }

    // Many threads each opening and closing their own client connection.
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_connect_threads(String host, String port) throws InterruptedException {
        int thread_num = 100;
        ForkJoinPool executor = new ForkJoinPool();
        for (int i = 0; i < thread_num; i++) {
            executor.execute(
                    () -> {
                        MilvusClient client = new MilvusGrpcClient();
                        ConnectParam connectParam = new ConnectParam.Builder()
                                .withHost(host)
                                .withPort(port)
                                .build();
                        client.connect(connectParam);
                        assert(client.connected());
                        try {
                            client.disconnect();
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                    });
        }
        executor.awaitQuiescence(100, TimeUnit.SECONDS);
        executor.shutdown();
    }

    // Concurrent inserts into one table; afterwards the row count must equal
    // thread_num * nb (every task's insert landed exactly once).
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 10;
        List<List<Float>> vectors = gen_vectors(nb,false);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        ForkJoinPool executor = new ForkJoinPool();
        for (int i = 0; i < thread_num; i++) {
            executor.execute(
                    () -> {
                        InsertResponse res_insert = client.insert(insertParam);
                        assert (res_insert.getResponse().ok());
                    });
        }
        executor.awaitQuiescence(100, TimeUnit.SECONDS);
        executor.shutdown();

        // Give the server a moment to flush before counting rows.
        Thread.sleep(2000);
        TableParam tableParam = new TableParam.Builder(tableName).build();
        GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
        Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
    }

    // Concurrent insert + createIndex on the same table; row count must still
    // reflect every insert.
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 50;
        List<List<Float>> vectors = gen_vectors(nb,false);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        ForkJoinPool executor = new ForkJoinPool();
        for (int i = 0; i < thread_num; i++) {
            executor.execute(
                    () -> {
                        InsertResponse res_insert = client.insert(insertParam);
                        Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
                                .withNList(n_list)
                                .build();
                        CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
                        client.createIndex(createIndexParam);
                        assert (res_insert.getResponse().ok());
                    });
        }
        executor.awaitQuiescence(300, TimeUnit.SECONDS);
        executor.shutdown();
        Thread.sleep(2000);
        TableParam tableParam = new TableParam.Builder(tableName).build();
        GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
        Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
    }

    // Concurrent insert-then-search; the top hit of a query drawn from the
    // inserted data is expected to be an exact match (distance 0 for an L2
    // table, 1 for an IP table with normalized vectors).
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 50;
        int nq = 5;
        List<List<Float>> vectors = gen_vectors(nb, true);
        List<List<Float>> queryVectors = vectors.subList(0,nq);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        ForkJoinPool executor = new ForkJoinPool();
        for (int i = 0; i < thread_num; i++) {
            executor.execute(
                    () -> {
                        InsertResponse res_insert = client.insert(insertParam);
                        assert (res_insert.getResponse().ok());
                        try {
                            // Let the insert become searchable before querying.
                            TimeUnit.SECONDS.sleep(1);
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                        SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
                        SearchResponse res_search = client.search(searchParam);
                        assert (res_search.getResponse().ok());
                        List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
                        double distance = res.get(0).get(0).getDistance();
                        // Table name prefix encodes the metric type (set by the provider).
                        if (tableName.startsWith("L2")) {
                            Assert.assertEquals(distance, 0.0, epsilon);
                        }else if (tableName.startsWith("IP")) {
                            Assert.assertEquals(distance, 1.0, epsilon);
                        }
                    });
        }
        executor.awaitQuiescence(300, TimeUnit.SECONDS);
        executor.shutdown();
        Thread.sleep(2000);
        TableParam tableParam = new TableParam.Builder(tableName).build();
        GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
        Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
    }

    // Each thread performs a full lifecycle on its own randomly named table:
    // connect, create, insert, drop, disconnect.
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_create_insert_delete_threads(String host, String port) throws InterruptedException {
        int thread_num = 100;
        List<List<Float>> vectors = gen_vectors(nb,false);
        ForkJoinPool executor = new ForkJoinPool();
        for (int i = 0; i < thread_num; i++) {
            executor.execute(
                    () -> {
                        MilvusClient client = new MilvusGrpcClient();
                        ConnectParam connectParam = new ConnectParam.Builder()
                                .withHost(host)
                                .withPort(port)
                                .build();
                        client.connect(connectParam);
                        String tableName = RandomStringUtils.randomAlphabetic(10);
                        TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                                .withIndexFileSize(index_file_size)
                                .withMetricType(MetricType.IP)
                                .build();
                        TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
                        client.createTable(tableSchemaParam);
                        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
                        client.insert(insertParam);
                        TableParam tableParam = new TableParam.Builder(tableName).build();
                        Response response = client.dropTable(tableParam);
                        Assert.assertTrue(response.ok());
                        try {
                            client.disconnect();
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                    });
        }
        executor.awaitQuiescence(100, TimeUnit.SECONDS);
        executor.shutdown();
    }

}
|
|
@ -0,0 +1,28 @@
|
|||
package com;
|
||||
|
||||
import io.milvus.client.ConnectParam;
|
||||
import io.milvus.client.MilvusClient;
|
||||
import io.milvus.client.MilvusGrpcClient;
|
||||
import io.milvus.client.Response;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class TestPing {
|
||||
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
|
||||
public void test_server_status(String host, String port){
|
||||
System.out.println("Host: "+host+", Port: "+port);
|
||||
MilvusClient client = new MilvusGrpcClient();
|
||||
ConnectParam connectParam = new ConnectParam.Builder()
|
||||
.withHost(host)
|
||||
.withPort(port)
|
||||
.build();
|
||||
client.connect(connectParam);
|
||||
Response res = client.serverStatus();
|
||||
assert (res.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_server_status_without_connected(MilvusGrpcClient client, String tableName){
|
||||
Response res = client.serverStatus();
|
||||
assert (!res.ok());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,480 @@
|
|||
package com;
|
||||
|
||||
import io.milvus.client.*;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class TestSearchVectors {
|
||||
int index_file_size = 10;
|
||||
int dimension = 128;
|
||||
int n_list = 1024;
|
||||
int default_n_list = 16384;
|
||||
int nb = 100000;
|
||||
int n_probe = 20;
|
||||
int top_k = 10;
|
||||
double epsilon = 0.001;
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
IndexType defaultIndexType = IndexType.FLAT;
|
||||
|
||||
|
||||
public List<Float> normalize(List<Float> w2v){
|
||||
float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
|
||||
final float norm = (float) Math.sqrt(squareSum);
|
||||
w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
|
||||
return w2v;
|
||||
}
|
||||
|
||||
public List<List<Float>> gen_vectors(int nb, boolean norm) {
|
||||
List<List<Float>> xb = new ArrayList<>();
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
List<Float> vector = new ArrayList<>();
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
vector.add(random.nextFloat());
|
||||
}
|
||||
if (norm == true) {
|
||||
vector = normalize(vector);
|
||||
}
|
||||
xb.add(vector);
|
||||
}
|
||||
return xb;
|
||||
}
|
||||
|
||||
public static Date getDeltaDate(int delta) {
|
||||
Date today = new Date();
|
||||
Calendar c = Calendar.getInstance();
|
||||
c.setTime(today);
|
||||
c.add(Calendar.DAY_OF_MONTH, delta);
|
||||
return c.getTime();
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
|
||||
String tableNameNew = tableName + "_";
|
||||
int nq = 5;
|
||||
int nb = 100;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableNameNew, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (!res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVFLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.size(), nq);
|
||||
Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_ids_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVFLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, true);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<Long> vectorIds;
|
||||
vectorIds = Stream.iterate(0L, n -> n)
|
||||
.limit(nb)
|
||||
.collect(Collectors.toList());
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVFLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.size(), nq);
|
||||
Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_distance_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVFLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, true);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
double distance = res_search.get(0).get(0).getDistance();
|
||||
if (tableName.startsWith("L2")) {
|
||||
Assert.assertEquals(distance, 0.0, epsilon);
|
||||
}else if (tableName.startsWith("IP")) {
|
||||
Assert.assertEquals(distance, 1.0, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.size(), nq);
|
||||
Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.size(), nq);
|
||||
Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_distance_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
int nb = 1000;
|
||||
List<List<Float>> vectors = gen_vectors(nb, true);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(default_n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<Double>> res_search = client.search(searchParam).getResultDistancesList();
|
||||
for (int i = 0; i < nq; i++) {
|
||||
double distance = res_search.get(i).get(0);
|
||||
System.out.println(distance);
|
||||
if (tableName.startsWith("L2")) {
|
||||
Assert.assertEquals(distance, 0.0, epsilon);
|
||||
}else if (tableName.startsWith("IP")) {
|
||||
Assert.assertEquals(distance, 1.0, epsilon);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_index_FLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.FLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.size(), nq);
|
||||
Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_FLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.FLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res_search.size(), nq);
|
||||
Assert.assertEquals(res_search.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.FLAT;
|
||||
int nb = 100000;
|
||||
int nq = 1000;
|
||||
int top_k = 2048;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withTimeout(1).build();
|
||||
System.out.println(new Date());
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (!res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_FLAT_big_data_size(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.FLAT;
|
||||
int nb = 100000;
|
||||
int nq = 2000;
|
||||
int top_k = 2048;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
System.out.println(new Date());
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_distance_FLAT(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.FLAT;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, true);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
|
||||
List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
|
||||
double distance = res_search.get(0).get(0).getDistance();
|
||||
if (tableName.startsWith("L2")) {
|
||||
Assert.assertEquals(distance, 0.0, epsilon);
|
||||
}else if (tableName.startsWith("IP")) {
|
||||
Assert.assertEquals(distance, 1.0, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_invalid_n_probe(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
int n_probe_new = 0;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe_new).withTopK(top_k).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (!res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_invalid_top_k(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
int top_k_new = 0;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k_new).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (!res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
// TODO(review): disabled test for a null query-vector list. Re-enable once the
// expected failure mode of searching with null vectors is confirmed, or delete
// it if the scenario is covered elsewhere.
//    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
//    public void test_search_invalid_query_vectors(MilvusClient client, String tableName) throws InterruptedException {
//        IndexType indexType = IndexType.IVF_SQ8;
//        int nq = 5;
//        List<List<Float>> vectors = gen_vectors(nb, false);
//        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
//        client.insert(insertParam);
//        Index index = new Index.Builder().withIndexType(indexType)
//                .withNList(n_list)
//                .build();
//        CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
//        client.createIndex(createIndexParam);
//        TableParam tableParam = new TableParam.Builder(tableName).build();
//        SearchParam searchParam = new SearchParam.Builder(tableName, null).withNProbe(n_probe).withTopK(top_k).build();
//        SearchResponse res_search = client.search(searchParam);
//        assert (!res_search.getResponse().ok());
//    }
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_index_range(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<DateRange> dateRange = new ArrayList<>();
|
||||
dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1)));
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (res_search.getResponse().ok());
|
||||
List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res.size(), nq);
|
||||
Assert.assertEquals(res.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_range(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<DateRange> dateRange = new ArrayList<>();
|
||||
dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1)));
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (res_search.getResponse().ok());
|
||||
List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res.size(), nq);
|
||||
Assert.assertEquals(res.get(0).size(), top_k);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_index_range_no_result(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<DateRange> dateRange = new ArrayList<>();
|
||||
dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1)));
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (res_search.getResponse().ok());
|
||||
List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res.size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_range_no_result(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<DateRange> dateRange = new ArrayList<>();
|
||||
dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1)));
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (res_search.getResponse().ok());
|
||||
List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
|
||||
Assert.assertEquals(res.size(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_index_range_invalid(MilvusClient client, String tableName) throws InterruptedException {
|
||||
IndexType indexType = IndexType.IVF_SQ8;
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<DateRange> dateRange = new ArrayList<>();
|
||||
dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1)));
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Index index = new Index.Builder().withIndexType(indexType)
|
||||
.withNList(n_list)
|
||||
.build();
|
||||
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
|
||||
client.createIndex(createIndexParam);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (!res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_search_range_invalid(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int nq = 5;
|
||||
List<List<Float>> vectors = gen_vectors(nb, false);
|
||||
List<List<Float>> queryVectors = vectors.subList(0,nq);
|
||||
List<DateRange> dateRange = new ArrayList<>();
|
||||
dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1)));
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
|
||||
client.insert(insertParam);
|
||||
Thread.sleep(1000);
|
||||
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
|
||||
SearchResponse res_search = client.search(searchParam);
|
||||
assert (!res_search.getResponse().ok());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
package com;
|
||||
|
||||
|
||||
import io.milvus.client.*;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class TestTable {
|
||||
int index_file_size = 50;
|
||||
int dimension = 128;
|
||||
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_create_table(MilvusClient client, String tableName){
|
||||
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
|
||||
.withIndexFileSize(index_file_size)
|
||||
.withMetricType(MetricType.L2)
|
||||
.build();
|
||||
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
||||
Response res = client.createTable(tableSchemaParam);
|
||||
assert(res.ok());
|
||||
Assert.assertEquals(res.ok(), true);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_create_table_disconnect(MilvusClient client, String tableName){
|
||||
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
|
||||
.withIndexFileSize(index_file_size)
|
||||
.withMetricType(MetricType.L2)
|
||||
.build();
|
||||
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
||||
Response res = client.createTable(tableSchemaParam);
|
||||
assert(!res.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_create_table_repeatably(MilvusClient client, String tableName){
|
||||
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
|
||||
.withIndexFileSize(index_file_size)
|
||||
.withMetricType(MetricType.L2)
|
||||
.build();
|
||||
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
||||
Response res = client.createTable(tableSchemaParam);
|
||||
Assert.assertEquals(res.ok(), true);
|
||||
Response res_new = client.createTable(tableSchemaParam);
|
||||
Assert.assertEquals(res_new.ok(), false);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_create_table_wrong_params(MilvusClient client, String tableName){
|
||||
Integer dimension = 0;
|
||||
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
|
||||
.withIndexFileSize(index_file_size)
|
||||
.withMetricType(MetricType.L2)
|
||||
.build();
|
||||
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
||||
Response res = client.createTable(tableSchemaParam);
|
||||
System.out.println(res.toString());
|
||||
Assert.assertEquals(res.ok(), false);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_show_tables(MilvusClient client, String tableName){
|
||||
Integer tableNum = 10;
|
||||
ShowTablesResponse res = null;
|
||||
for (int i = 0; i < tableNum; ++i) {
|
||||
String tableNameNew = tableName+"_"+Integer.toString(i);
|
||||
TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension)
|
||||
.withIndexFileSize(index_file_size)
|
||||
.withMetricType(MetricType.L2)
|
||||
.build();
|
||||
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
||||
client.createTable(tableSchemaParam);
|
||||
List<String> tableNames = client.showTables().getTableNames();
|
||||
Assert.assertTrue(tableNames.contains(tableNameNew));
|
||||
}
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_show_tables_without_connect(MilvusClient client, String tableName){
|
||||
ShowTablesResponse res = client.showTables();
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_drop_table(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Response res = client.dropTable(tableParam);
|
||||
assert(res.ok());
|
||||
Thread.currentThread().sleep(1000);
|
||||
List<String> tableNames = client.showTables().getTableNames();
|
||||
Assert.assertFalse(tableNames.contains(tableName));
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_drop_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName+"_").build();
|
||||
Response res = client.dropTable(tableParam);
|
||||
assert(!res.ok());
|
||||
List<String> tableNames = client.showTables().getTableNames();
|
||||
Assert.assertTrue(tableNames.contains(tableName));
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_drop_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Response res = client.dropTable(tableParam);
|
||||
assert(!res.ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_describe_table(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeTableResponse res = client.describeTable(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
TableSchema tableSchema = res.getTableSchema().get();
|
||||
Assert.assertEquals(tableSchema.getDimension(), dimension);
|
||||
Assert.assertEquals(tableSchema.getTableName(), tableName);
|
||||
Assert.assertEquals(tableSchema.getIndexFileSize(), index_file_size);
|
||||
Assert.assertEquals(tableSchema.getMetricType().name(), tableName.substring(0,2));
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_describe_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
DescribeTableResponse res = client.describeTable(tableParam);
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_has_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName+"_").build();
|
||||
HasTableResponse res = client.hasTable(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Assert.assertFalse(res.hasTable());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_has_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
HasTableResponse res = client.hasTable(tableParam);
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_has_table(MilvusClient client, String tableName) throws InterruptedException {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
HasTableResponse res = client.hasTable(tableParam);
|
||||
assert(res.getResponse().ok());
|
||||
Assert.assertTrue(res.hasTable());
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
package com;
|
||||
|
||||
import io.milvus.client.*;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public class TestTableCount {
|
||||
// Schema parameters shared by the row-count tests below.
// NOTE(review): index_file_size is not read by the visible methods of this
// class — presumably mirrors the sibling TestTable field; confirm before removal.
int index_file_size = 50;
// Dimensionality of every vector produced by gen_vectors().
int dimension = 128;
|
||||
public List<List<Float>> gen_vectors(Integer nb) {
|
||||
List<List<Float>> xb = new ArrayList<>();
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
ArrayList<Float> vector = new ArrayList<>();
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
vector.add(random.nextFloat());
|
||||
}
|
||||
xb.add(vector);
|
||||
}
|
||||
return xb;
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_table_count_no_vectors(MilvusClient client, String tableName) {
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_table_count_table_not_existed(MilvusClient client, String tableName) {
|
||||
TableParam tableParam = new TableParam.Builder(tableName+"_").build();
|
||||
GetTableRowCountResponse res = client.getTableRowCount(tableParam);
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
|
||||
public void test_table_count_without_connect(MilvusClient client, String tableName) {
|
||||
TableParam tableParam = new TableParam.Builder(tableName+"_").build();
|
||||
GetTableRowCountResponse res = client.getTableRowCount(tableParam);
|
||||
assert(!res.getResponse().ok());
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_table_count(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int nb = 10000;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
// Add vectors
|
||||
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();;
|
||||
client.insert(insertParam);
|
||||
Thread.currentThread().sleep(1000);
|
||||
TableParam tableParam = new TableParam.Builder(tableName).build();
|
||||
Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
|
||||
public void test_table_count_multi_tables(MilvusClient client, String tableName) throws InterruptedException {
|
||||
int nb = 10000;
|
||||
List<List<Float>> vectors = gen_vectors(nb);
|
||||
Integer tableNum = 10;
|
||||
GetTableRowCountResponse res = null;
|
||||
for (int i = 0; i < tableNum; ++i) {
|
||||
String tableNameNew = tableName + "_" + Integer.toString(i);
|
||||
TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension)
|
||||
.withIndexFileSize(index_file_size)
|
||||
.withMetricType(MetricType.L2)
|
||||
.build();
|
||||
TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
|
||||
client.createTable(tableSchemaParam);
|
||||
// Add vectors
|
||||
InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build();
|
||||
client.insert(insertParam);
|
||||
}
|
||||
Thread.currentThread().sleep(1000);
|
||||
for (int i = 0; i < tableNum; ++i) {
|
||||
String tableNameNew = tableName + "_" + Integer.toString(i);
|
||||
TableParam tableParam = new TableParam.Builder(tableNameNew).build();
|
||||
res = client.getTableRowCount(tableParam);
|
||||
Assert.assertEquals(res.getTableRowCount(), nb);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
<suite name="Test-class Suite">
|
||||
<test name="Test-class test" >
|
||||
<classes>
|
||||
<class name="com.TestConnect" />
|
||||
<class name="com.TestMix" />
|
||||
</classes>
|
||||
</test>
|
||||
</suite>
|
|
@ -0,0 +1,2 @@
|
|||
__pycache__/
|
||||
logs/
|
|
@ -0,0 +1,149 @@
|
|||
import pdb
|
||||
import random
|
||||
import logging
|
||||
import json
|
||||
import time, datetime
|
||||
from multiprocessing import Process
|
||||
import numpy
|
||||
import sklearn.preprocessing
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
|
||||
logger = logging.getLogger("milvus_ann_acc.client")
|
||||
|
||||
SERVER_HOST_DEFAULT = "127.0.0.1"
|
||||
SERVER_PORT_DEFAULT = 19530
|
||||
|
||||
|
||||
def time_wrapper(func):
    """Decorator that logs the wall-clock duration of each call at INFO level."""
    def wrapper(*args, **kwargs):
        begin = time.time()
        ret = func(*args, **kwargs)
        elapsed = round(time.time() - begin, 2)
        logger.info("Milvus {} run in {}s".format(func.__name__, elapsed))
        return ret
    return wrapper
|
||||
|
||||
|
||||
class MilvusClient(object):
    """Thin wrapper around the Milvus SDK bound to a single table.

    Connects during construction (127.0.0.1:19530 unless ip/port are given)
    and exposes the table-scoped operations used by the accuracy runner.
    """

    def __init__(self, table_name=None, ip=None, port=None):
        self._milvus = Milvus()
        self._table_name = table_name
        # Fix: previously only set inside create_table(), so insert()/query()
        # raised AttributeError when the table pre-existed. None disables the
        # inner-product normalization path.
        self._metric_type = None
        # The original wrapped this in `try/except Exception: raise`, which is
        # a no-op; let connection errors propagate directly.
        if not ip:
            self._milvus.connect(
                host=SERVER_HOST_DEFAULT,
                port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(
                host=ip,
                port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log and raise when a Milvus call returned a non-OK status."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create the table; metric_type must be "l2" or "ip".

        Raises ValueError for other metric names — the original logged the
        error but still created the table with a bogus metric value.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
            raise ValueError("Not supported metric_type: %s" % metric_type)
        self._metric_type = metric_type
        create_param = {'table_name': table_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids):
        """Add vectors (L2-normalized first when the metric is IP); returns (status, ids)."""
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        X = X.astype(numpy.float32)
        status, result = self._milvus.add_vectors(self._table_name, X.tolist(), ids=ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; index_type in flat/ivf_flat/ivf_sq8/ivf_sq8h/mix_nsg.

        An unrecognized name is passed through unchanged and rejected by the
        server (behavior kept for backward compatibility).
        """
        if index_type == "flat":
            index_type = IndexType.FLAT
        elif index_type == "ivf_flat":
            index_type = IndexType.IVFLAT
        elif index_type == "ivf_sq8":
            index_type = IndexType.IVF_SQ8
        elif index_type == "ivf_sq8h":
            index_type = IndexType.IVF_SQ8H
        elif index_type == "mix_nsg":
            index_type = IndexType.MIX_NSG
        index_params = {
            "index_type": index_type,
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
        self.check_status(status)

    def describe_index(self):
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search the table and return only the result ids, one list per query vector."""
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        X = X.astype(numpy.float32)
        status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, X.tolist())
        self.check_status(status)
        ids = []
        for result in results:
            ids.append([item.id for item in result])
        return ids

    def count(self):
        # SDK returns (status, count); only the count is useful here.
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Drop the table and poll up to *timeout* seconds until it reports empty."""
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout:
            if self.count():
                time.sleep(1)
                i = i + 1
            else:
                break
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        return self._milvus.describe_table(self._table_name)

    def exists_table(self):
        return self._milvus.has_table(self._table_name)

    @time_wrapper
    def preload_table(self):
        return self._milvus.preload_table(self._table_name)
|
|
@ -0,0 +1,17 @@
|
|||
datasets:
|
||||
sift-128-euclidean:
|
||||
cpu_cache_size: 16
|
||||
gpu_cache_size: 5
|
||||
index_file_size: [1024]
|
||||
nytimes-16-angular:
|
||||
cpu_cache_size: 16
|
||||
gpu_cache_size: 5
|
||||
index_file_size: [1024]
|
||||
|
||||
index:
|
||||
index_types: ['flat', 'ivf_flat', 'ivf_sq8']
|
||||
nlists: [8092, 16384]
|
||||
|
||||
search:
|
||||
nprobes: [1, 8, 32]
|
||||
top_ks: [10]
|
|
@ -0,0 +1,26 @@
|
|||
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument(
|
||||
'--dataset',
|
||||
metavar='NAME',
|
||||
help='the dataset to load training points from',
|
||||
default='glove-100-angular',
|
||||
choices=DATASETS.keys())
|
||||
parser.add_argument(
|
||||
"-k", "--count",
|
||||
default=10,
|
||||
type=positive_int,
|
||||
help="the number of near neighbours to search for")
|
||||
parser.add_argument(
|
||||
'--definitions',
|
||||
metavar='FILE',
|
||||
help='load algorithm definitions from FILE',
|
||||
default='algos.yaml')
|
||||
parser.add_argument(
|
||||
'--image-tag',
|
||||
default=None,
|
||||
help='pull image first')
|
|
@ -0,0 +1,132 @@
|
|||
import os
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
import sys
|
||||
import h5py
|
||||
import numpy
|
||||
import logging
|
||||
from logging import handlers
|
||||
|
||||
from client import MilvusClient
|
||||
|
||||
LOG_FOLDER = "logs"
|
||||
logger = logging.getLogger("milvus_ann_acc")
|
||||
|
||||
formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
|
||||
if not os.path.exists(LOG_FOLDER):
|
||||
os.system('mkdir -p %s' % LOG_FOLDER)
|
||||
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10)
|
||||
fileTimeHandler.suffix = "%Y%m%d.log"
|
||||
fileTimeHandler.setFormatter(formatter)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
fileTimeHandler.setFormatter(formatter)
|
||||
logger.addHandler(fileTimeHandler)
|
||||
|
||||
|
||||
def get_dataset_fn(dataset_name):
    """Return the absolute path of the HDF5 file for *dataset_name*.

    Raises Exception when the fixed dataset directory is absent.
    """
    base_dir = "/test/milvus/ann_hdf5/"
    if not os.path.exists(base_dir):
        raise Exception("%s not exists" % base_dir)
    return os.path.join(base_dir, '%s.hdf5' % dataset_name)
|
||||
|
||||
|
||||
def get_dataset(dataset_name):
    """Open the dataset's HDF5 file and return the h5py File handle.

    Callers index the handle like a dict (dataset["train"], dataset["test"]).
    """
    hdf5_fn = get_dataset_fn(dataset_name)
    # Pass the mode explicitly: opening without one is deprecated in h5py 2.x
    # and the default changed in h5py 3; the file is only read here.
    hdf5_f = h5py.File(hdf5_fn, "r")
    return hdf5_f
|
||||
|
||||
|
||||
def parse_dataset_name(dataset_name):
    """Split an ann-benchmarks dataset name like ``sift-128-euclidean``.

    Returns (table prefix, dimension, milvus metric name), e.g.
    ("annsift", 128, "l2").

    Raises ValueError for an unrecognized distance suffix — the original fell
    through with ``metric_type`` unbound and hit an UnboundLocalError instead.
    """
    data_type = dataset_name.split("-")[0]
    dimension = int(dataset_name.split("-")[1])
    metric = dataset_name.split("-")[-1]
    if metric == "euclidean":
        metric_type = "l2"
    elif metric == "angular":
        metric_type = "ip"
    else:
        raise ValueError("Unsupported metric: %s" % metric)
    return ("ann"+data_type, dimension, metric_type)
|
||||
|
||||
|
||||
def get_table_name(dataset_name, index_file_size):
    """Build the table name encoding type, row count (in millions), index file size,
    dimension and metric, e.g. ``annsift_1m_1024_128_l2``."""
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    dataset = get_dataset(dataset_name)
    rows_in_millions = str(len(dataset["train"]) // 1000000) + "m"
    parts = [data_type, rows_in_millions, str(index_file_size), str(dimension), metric_type]
    return "_".join(parts)
|
||||
|
||||
|
||||
def main(dataset_name, index_file_size, nlist=16384, force=False):
    """Run the accuracy benchmark for one dataset / index_file_size combination.

    Creates (or re-creates when force=True) the table, bulk-inserts the train
    split, then for each index type builds the index and measures recall
    against the dataset's ground-truth neighbors.
    """
    # Hard-coded evaluation parameters for the recall sweep.
    top_k = 10
    nprobes = [32, 128]

    dataset = get_dataset(dataset_name)
    table_name = get_table_name(dataset_name, index_file_size)
    m = MilvusClient(table_name)
    if m.exists_table():
        if force is True:
            logger.info("Re-create table: %s" % table_name)
            m.delete()
            time.sleep(10)
        else:
            # Keep existing data; nothing to do for this combination.
            logger.info("Table name: %s existed" % table_name)
            return
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    m.create_table(table_name, dimension, index_file_size, metric_type)
    print(m.describe())
    vectors = numpy.array(dataset["train"])
    query_vectors = numpy.array(dataset["test"])
    # m.insert(vectors)

    # Insert in chunks of `interval` rows; ids are the row offsets so they
    # line up with the ground-truth neighbor indices.
    interval = 100000
    loops = len(vectors) // interval + 1

    for i in range(loops):
        start = i*interval
        end = min((i+1)*interval, len(vectors))
        tmp_vectors = vectors[start:end]
        if start < end:
            m.insert(tmp_vectors, ids=[i for i in range(start, end)])

    # Wait for server-side flush/merge before counting.
    time.sleep(60)
    print(m.count())

    for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]:
        m.create_index(index_type, nlist)
        print(m.describe_index())
        if m.count() != len(vectors):
            # Row count mismatch means the insert did not fully land; abort.
            return
        m.preload_table()
        true_ids = numpy.array(dataset["neighbors"])
        for nprobe in nprobes:
            print("nprobe: %s" % nprobe)
            # sum_radio accumulates per-query recall@top_k ("radio" = ratio).
            sum_radio = 0.0; avg_radio = 0.0
            result_ids = m.query(query_vectors, top_k, nprobe)
            # print(result_ids[:10])
            for index, result_item in enumerate(result_ids):
                if len(set(true_ids[index][:top_k])) != len(set(result_item)):
                    logger.info("Error happened")
                    # logger.info(query_vectors[index])
                    # logger.info(true_ids[index][:top_k], result_item)
                tmp = set(true_ids[index][:top_k]).intersection(set(result_item))
                sum_radio = sum_radio + (len(tmp) / top_k)
            avg_radio = round(sum_radio / len(result_ids), 4)
            logger.info(avg_radio)
        # Drop before building the next index type over the same data.
        m.drop_index()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("glove-25-angular")
|
||||
# main("sift-128-euclidean", 1024, force=True)
|
||||
for index_file_size in [50, 1024]:
|
||||
print("Index file size: %d" % index_file_size)
|
||||
main("glove-25-angular", index_file_size, force=True)
|
||||
|
||||
print("sift-128-euclidean")
|
||||
for index_file_size in [50, 1024]:
|
||||
print("Index file size: %d" % index_file_size)
|
||||
main("sift-128-euclidean", index_file_size, force=True)
|
||||
# m = MilvusClient()
|
|
@ -0,0 +1,8 @@
|
|||
random_data
|
||||
benchmark_logs/
|
||||
db/
|
||||
logs/
|
||||
*idmap*.txt
|
||||
__pycache__/
|
||||
venv
|
||||
.idea
|
|
@ -0,0 +1,57 @@
|
|||
# Quick start
|
||||
|
||||
## 运行
|
||||
|
||||
### 运行示例:
|
||||
|
||||
`python3 main.py --image=registry.zilliz.com/milvus/engine:branch-0.3.1-release --run-count=2 --run-type=performance`
|
||||
|
||||
### 运行参数:
|
||||
|
||||
--image: 容器模式,传入镜像名称,如传入,则运行测试时,会先进行pull image,基于image生成milvus server容器
|
||||
|
||||
--local: 与image参数互斥,本地模式,连接使用本地启动的milvus server进行测试
|
||||
|
||||
--run-count: 重复运行次数
|
||||
|
||||
--suites: 测试集配置文件,默认使用suites.yaml
|
||||
|
||||
--run-type: 测试类型,包括性能--performance、准确性测试--accuracy以及稳定性--stability
|
||||
|
||||
### 测试集配置文件:
|
||||
|
||||
`operations:
|
||||
|
||||
insert:
|
||||
|
||||
[
|
||||
{"table.index_type": "ivf_flat", "server.index_building_threshold": 300, "table.size": 2000000, "table.ni": 100000, "table.dim": 512},
|
||||
]
|
||||
|
||||
build: []
|
||||
|
||||
query:
|
||||
|
||||
[
|
||||
{"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 1, "server.use_blas_threshold": 800},
|
||||
{"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 10, "server.use_blas_threshold": 20},
|
||||
]`
|
||||
|
||||
## 测试结果:
|
||||
|
||||
性能:
|
||||
|
||||
`INFO:milvus_benchmark.runner:Start warm query, query params: top-k: 1, nq: 1
|
||||
|
||||
INFO:milvus_benchmark.client:query run in 19.19s
|
||||
INFO:milvus_benchmark.runner:Start query, query params: top-k: 64, nq: 10, actually length of vectors: 10
|
||||
INFO:milvus_benchmark.runner:Start run query, run 1 of 1
|
||||
INFO:milvus_benchmark.client:query run in 0.2s
|
||||
INFO:milvus_benchmark.runner:Avarage query time: 0.20
|
||||
INFO:milvus_benchmark.runner:[[0.2]]`
|
||||
|
||||
**│ 10 │ 0.2 │**
|
||||
|
||||
准确率:
|
||||
|
||||
`INFO:milvus_benchmark.runner:Avarage accuracy: 1.0`
|
|
@ -0,0 +1,244 @@
|
|||
import pdb
|
||||
import random
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
import time, datetime
|
||||
from multiprocessing import Process
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.client")
|
||||
|
||||
SERVER_HOST_DEFAULT = "127.0.0.1"
|
||||
SERVER_PORT_DEFAULT = 19530
|
||||
|
||||
|
||||
def time_wrapper(func):
    """Decorator that logs how long each wrapped call took, at INFO level."""
    def wrapper(*args, **kwargs):
        started = time.time()
        outcome = func(*args, **kwargs)
        duration = round(time.time() - started, 2)
        logger.info("Milvus {} run in {}s".format(func.__name__, duration))
        return outcome
    return wrapper
|
||||
|
||||
|
||||
class MilvusClient(object):
    """Milvus SDK wrapper bound to one table, used by the benchmark runner.

    Connects during construction (127.0.0.1:19530 unless ip/port are given).
    """

    def __init__(self, table_name=None, ip=None, port=None):
        self._milvus = Milvus()
        self._table_name = table_name
        # The original wrapped the connect in `try/except Exception: raise`,
        # which is a no-op; let connection errors propagate directly.
        if not ip:
            self._milvus.connect(
                host=SERVER_HOST_DEFAULT,
                port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(
                host=ip,
                port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log and raise when a Milvus call returned a non-OK status."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create the table; metric_type must be "l2" or "ip".

        Raises ValueError for other metric names — the original logged the
        error but still created the table with a bogus metric value.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
            raise ValueError("Not supported metric_type: %s" % metric_type)
        create_param = {'table_name': table_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        """Add raw vectors (no normalization in the benchmark client); returns (status, ids)."""
        status, result = self._milvus.add_vectors(self._table_name, X, ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; index_type in flat/ivf_flat/ivf_sq8/mix_nsg/ivf_sq8h.

        An unrecognized name is passed through unchanged and rejected by the
        server (behavior kept for backward compatibility).
        """
        if index_type == "flat":
            index_type = IndexType.FLAT
        elif index_type == "ivf_flat":
            index_type = IndexType.IVFLAT
        elif index_type == "ivf_sq8":
            index_type = IndexType.IVF_SQ8
        elif index_type == "mix_nsg":
            index_type = IndexType.MIX_NSG
        elif index_type == "ivf_sq8h":
            index_type = IndexType.IVF_SQ8H
        index_params = {
            "index_type": index_type,
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
        self.check_status(status)

    def describe_index(self):
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search the table; returns the raw (status, results) pair from the SDK."""
        status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X)
        self.check_status(status)
        return status, result

    def count(self):
        # SDK returns (status, count); only the count is useful here.
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Drop the table and poll up to *timeout* seconds until it reports empty."""
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout:
            if self.count():
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        # Bug fix: the original tested `i < timeout` and therefore logged
        # "Delete table timeout" exactly when the delete had NOT timed out.
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        return self._milvus.describe_table(self._table_name)

    def exists_table(self):
        return self._milvus.has_table(self._table_name)

    @time_wrapper
    def preload_table(self):
        # Long timeout: preloading a large table into cache can take minutes.
        return self._milvus.preload_table(self._table_name, timeout=3000)
|
||||
|
||||
|
||||
def fit(table_name, X):
    """Insert X into *table_name* over a fresh connection and log the elapsed time.

    Used as the worker body for fit_concurrent(); each process opens its own
    connection.
    """
    milvus = Milvus()
    milvus.connect(host = SERVER_HOST_DEFAULT, port = SERVER_PORT_DEFAULT)
    start = time.time()
    status, ids = milvus.add_vectors(table_name, X)
    end = time.time()
    # Bug fix: `logger(status, ...)` invoked the Logger object itself, which
    # raises TypeError; log through logger.info instead.
    logger.info("%s insert run in %ss", status, round(end - start, 2))
|
||||
|
||||
|
||||
def fit_concurrent(table_name, process_num, vectors):
    """Spawn *process_num* fit() workers inserting the same vectors in parallel."""
    workers = []
    for _ in range(process_num):
        worker = Process(target=fit, args=(table_name, vectors, ))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# table_name = "sift_2m_20_128_l2"
|
||||
table_name = "test_tset1"
|
||||
m = MilvusClient(table_name)
|
||||
# m.create_table(table_name, 128, 50, "l2")
|
||||
|
||||
print(m.describe())
|
||||
# print(m.count())
|
||||
# print(m.describe_index())
|
||||
insert_vectors = [[random.random() for _ in range(128)] for _ in range(10000)]
|
||||
for i in range(5):
|
||||
m.insert(insert_vectors)
|
||||
print(m.create_index("ivf_sq8h", 16384))
|
||||
X = [insert_vectors[0]]
|
||||
top_k = 10
|
||||
nprobe = 10
|
||||
print(m.query(X, top_k, nprobe))
|
||||
|
||||
# # # print(m.drop_index())
|
||||
# # print(m.describe_index())
|
||||
# # sys.exit()
|
||||
# # # insert_vectors = [[random.random() for _ in range(128)] for _ in range(100000)]
|
||||
# # # for i in range(100):
|
||||
# # # m.insert(insert_vectors)
|
||||
# # # time.sleep(5)
|
||||
# # # print(m.describe_index())
|
||||
# # # print(m.drop_index())
|
||||
# # m.create_index("ivf_sq8h", 16384)
|
||||
# print(m.count())
|
||||
# print(m.describe_index())
|
||||
|
||||
|
||||
|
||||
# sys.exit()
|
||||
# print(m.create_index("ivf_sq8h", 16384))
|
||||
# print(m.count())
|
||||
# print(m.describe_index())
|
||||
import numpy as np
|
||||
|
||||
def mmap_fvecs(fname):
    """Memory-map an .fvecs file and view it as an (n, d) float32 array.

    Each record is one little-endian int32 dimension header followed by d
    float32 components; the header column is sliced off the returned view.
    """
    raw = np.memmap(fname, dtype='int32', mode='r')
    dim = raw[0]
    return raw.view('float32').reshape(-1, dim + 1)[:, 1:]
|
||||
|
||||
print(mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs"))
|
||||
# SIFT_SRC_QUERY_DATA_DIR = '/poc/yuncong/ann_1000m'
|
||||
# file_name = SIFT_SRC_QUERY_DATA_DIR+'/'+'query.npy'
|
||||
# data = numpy.load(file_name)
|
||||
# query_vectors = data[0:2].tolist()
|
||||
# print(len(query_vectors))
|
||||
# results = m.query(query_vectors, 10, 10)
|
||||
# result_ids = []
|
||||
# for result in results[1]:
|
||||
# tmp = []
|
||||
# for item in result:
|
||||
# tmp.append(item.id)
|
||||
# result_ids.append(tmp)
|
||||
# print(result_ids[0][:10])
|
||||
# # gt
|
||||
# file_name = SIFT_SRC_QUERY_DATA_DIR+"/gnd/"+"idx_1M.ivecs"
|
||||
# a = numpy.fromfile(file_name, dtype='int32')
|
||||
# d = a[0]
|
||||
# true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
|
||||
# print(true_ids[:3, :2])
|
||||
|
||||
# print(len(true_ids[0]))
|
||||
# import numpy as np
|
||||
# import sklearn.preprocessing
|
||||
|
||||
# def mmap_fvecs(fname):
|
||||
# x = np.memmap(fname, dtype='int32', mode='r')
|
||||
# d = x[0]
|
||||
# return x.view('float32').reshape(-1, d + 1)[:, 1:]
|
||||
|
||||
# data = mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs")
|
||||
# print(data[0], len(data[0]), len(data))
|
||||
|
||||
# total_size = 10000
|
||||
# # total_size = 1000000000
|
||||
# file_size = 1000
|
||||
# # file_size = 100000
|
||||
# file_num = total_size // file_size
|
||||
# for i in range(file_num):
|
||||
# fname = "/test/milvus/raw_data/deep1b/binary_96_%05d" % i
|
||||
# print(fname, i*file_size, (i+1)*file_size)
|
||||
# single_data = data[i*file_size : (i+1)*file_size]
|
||||
# single_data = sklearn.preprocessing.normalize(single_data, axis=1, norm='l2')
|
||||
# np.save(fname, single_data)
|
|
@ -0,0 +1,28 @@
|
|||
* GLOBAL:
|
||||
FORMAT = "%datetime | %level | %logger | %msg"
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log"
|
||||
ENABLED = true
|
||||
TO_FILE = true
|
||||
TO_STANDARD_OUTPUT = false
|
||||
SUBSECOND_PRECISION = 3
|
||||
PERFORMANCE_TRACKING = false
|
||||
MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB
|
||||
* DEBUG:
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log"
|
||||
ENABLED = true
|
||||
* WARNING:
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log"
|
||||
* TRACE:
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log"
|
||||
* VERBOSE:
|
||||
FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
|
||||
TO_FILE = false
|
||||
TO_STANDARD_OUTPUT = false
|
||||
## Error logs
|
||||
* ERROR:
|
||||
ENABLED = true
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log"
|
||||
* FATAL:
|
||||
ENABLED = true
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log"
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
cache_config:
|
||||
cache_insert_data: false
|
||||
cpu_cache_capacity: 16
|
||||
gpu_cache_capacity: 6
|
||||
cpu_cache_threshold: 0.85
|
||||
db_config:
|
||||
backend_url: sqlite://:@:/
|
||||
build_index_gpu: 0
|
||||
insert_buffer_size: 4
|
||||
preload_table: null
|
||||
primary_path: /opt/milvus
|
||||
secondary_path: null
|
||||
engine_config:
|
||||
use_blas_threshold: 20
|
||||
metric_config:
|
||||
collector: prometheus
|
||||
enable_monitor: true
|
||||
prometheus_config:
|
||||
port: 8080
|
||||
resource_config:
|
||||
resource_pool:
|
||||
- cpu
|
||||
- gpu0
|
||||
server_config:
|
||||
address: 0.0.0.0
|
||||
deploy_mode: single
|
||||
port: 19530
|
||||
time_zone: UTC+8
|
|
@ -0,0 +1,31 @@
|
|||
server_config:
|
||||
address: 0.0.0.0
|
||||
port: 19530
|
||||
deploy_mode: single
|
||||
time_zone: UTC+8
|
||||
|
||||
db_config:
|
||||
primary_path: /opt/milvus
|
||||
secondary_path:
|
||||
backend_url: sqlite://:@:/
|
||||
insert_buffer_size: 4
|
||||
build_index_gpu: 0
|
||||
preload_table:
|
||||
|
||||
metric_config:
|
||||
enable_monitor: false
|
||||
collector: prometheus
|
||||
prometheus_config:
|
||||
port: 8080
|
||||
|
||||
cache_config:
|
||||
cpu_cache_capacity: 16
|
||||
cpu_cache_threshold: 0.85
|
||||
cache_insert_data: false
|
||||
|
||||
engine_config:
|
||||
use_blas_threshold: 20
|
||||
|
||||
resource_config:
|
||||
resource_pool:
|
||||
- cpu
|
|
@ -0,0 +1,33 @@
|
|||
server_config:
|
||||
address: 0.0.0.0
|
||||
port: 19530
|
||||
deploy_mode: single
|
||||
time_zone: UTC+8
|
||||
|
||||
db_config:
|
||||
primary_path: /opt/milvus
|
||||
secondary_path:
|
||||
backend_url: sqlite://:@:/
|
||||
insert_buffer_size: 4
|
||||
build_index_gpu: 0
|
||||
preload_table:
|
||||
|
||||
metric_config:
|
||||
enable_monitor: false
|
||||
collector: prometheus
|
||||
prometheus_config:
|
||||
port: 8080
|
||||
|
||||
cache_config:
|
||||
cpu_cache_capacity: 16
|
||||
cpu_cache_threshold: 0.85
|
||||
cache_insert_data: false
|
||||
|
||||
engine_config:
|
||||
use_blas_threshold: 20
|
||||
|
||||
resource_config:
|
||||
resource_pool:
|
||||
- cpu
|
||||
- gpu0
|
||||
- gpu1
|
|
@ -0,0 +1,32 @@
|
|||
server_config:
|
||||
address: 0.0.0.0
|
||||
port: 19530
|
||||
deploy_mode: single
|
||||
time_zone: UTC+8
|
||||
|
||||
db_config:
|
||||
primary_path: /opt/milvus
|
||||
secondary_path:
|
||||
backend_url: sqlite://:@:/
|
||||
insert_buffer_size: 4
|
||||
build_index_gpu: 0
|
||||
preload_table:
|
||||
|
||||
metric_config:
|
||||
enable_monitor: false
|
||||
collector: prometheus
|
||||
prometheus_config:
|
||||
port: 8080
|
||||
|
||||
cache_config:
|
||||
cpu_cache_capacity: 16
|
||||
cpu_cache_threshold: 0.85
|
||||
cache_insert_data: false
|
||||
|
||||
engine_config:
|
||||
use_blas_threshold: 20
|
||||
|
||||
resource_config:
|
||||
resource_pool:
|
||||
- cpu
|
||||
- gpu0
|
|
@ -0,0 +1,51 @@
|
|||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from client import MilvusClient
|
||||
|
||||
nq = 100000
|
||||
dimension = 128
|
||||
run_count = 1
|
||||
table_name = "sift_10m_1024_128_ip"
|
||||
insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
|
||||
|
||||
def do_query(milvus, table_name, top_ks, nqs, nprobe, run_count):
    """Time milvus.query over every (nq, top_k) pair.

    Returns a matrix of average seconds: one row per nq, one column per top_k.
    Query vectors come from the module-level insert_vectors pool.
    """
    timings = []
    for nq in nqs:
        per_nq = []
        # The slice only depends on nq, so hoist it out of the top_k loop.
        vectors = insert_vectors[0:nq]
        for top_k in top_ks:
            total_query_time = 0.0
            for _ in range(run_count):
                begin = time.time()
                status, query_res = milvus.query(vectors, top_k, nprobe)
                total_query_time = total_query_time + (time.time() - begin)
                if status.code:
                    print(status.message)
            per_nq.append(round(total_query_time / run_count, 2))
        timings.append(per_nq)
    return timings
|
||||
|
||||
while 1:
|
||||
milvus_instance = MilvusClient(table_name, ip="192.168.1.197", port=19530)
|
||||
top_ks = random.sample([x for x in range(1, 100)], 4)
|
||||
nqs = random.sample([x for x in range(1, 1000)], 3)
|
||||
nprobe = random.choice([x for x in range(1, 500)])
|
||||
res = do_query(milvus_instance, table_name, top_ks, nqs, nprobe, run_count)
|
||||
status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
|
||||
if not status.OK():
|
||||
logger.error(status.message)
|
||||
|
||||
# status = milvus_instance.drop_index()
|
||||
if not status.OK():
|
||||
print(status.message)
|
||||
index_type = "ivf_sq8"
|
||||
status = milvus_instance.create_index(index_type, 16384)
|
||||
if not status.OK():
|
||||
print(status.message)
|
|
@ -0,0 +1,261 @@
|
|||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from client import MilvusClient
|
||||
import utils
|
||||
import parser
|
||||
from runner import Runner
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.docker")
|
||||
|
||||
|
||||
class DockerRunner(Runner):
    """Benchmark runner that drives a Milvus server inside a Docker container.

    For every parameter set in the suite definition it applies the requested
    server config, boots a container from ``self.image``, runs the operation
    against it and removes the container afterwards.
    """
    def __init__(self, image):
        # image: Docker image tag used to start the Milvus server.
        super(DockerRunner, self).__init__()
        self.image = image

    def run(self, definition, run_type=None):
        """Execute one suite ``definition`` for the given ``run_type``.

        :param definition: maps op type ("insert"/"query") to a dict holding
            "run_count" and "params" (a list of per-run parameter dicts)
        :param run_type: "performance", "accuracy" or "stability"

        NOTE(review): block nesting was reconstructed from a formatting-mangled
        source — verify against upstream before relying on exact control flow.
        """
        if run_type == "performance":
            for op_type, op_value in definition.items():
                # run docker mode
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None

                if op_type == "insert":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["table_name"]
                        volume_name = param["db_path_prefix"]
                        print(table_name)
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        # Apply every "server.*" key to the server config before boot.
                        for k, v in param.items():
                            if k.startswith("server."):
                                # Update server config
                                utils.modify_config(k, v, type="server", db_slave=None)
                        container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                        time.sleep(2)
                        milvus = MilvusClient(table_name)
                        # Check has table or not; drop any pre-existing table so
                        # the insert benchmark starts from an empty one.
                        if milvus.exists_table():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_table(table_name, dimension, index_file_size, metric_type)
                        res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
                        logger.info(res)

                        # wait for file merge
                        time.sleep(6 * (table_size / 500000))
                        # Clear up
                        utils.remove_container(container)

                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["dataset"]
                        volume_name = param["db_path_prefix"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        for k, v in param.items():
                            if k.startswith("server."):
                                utils.modify_config(k, v, type="server")
                        container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                        time.sleep(2)
                        milvus = MilvusClient(table_name)
                        logger.debug(milvus._milvus.show_tables())
                        # Check has table or not
                        if not milvus.exists_table():
                            logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                            continue
                        # parse index info
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # parse top-k, nq, nprobe
                        top_ks, nqs, nprobes = parser.search_params_parser(param)
                        for index_type in index_types:
                            for nlist in nlists:
                                result = milvus.describe_index()
                                logger.info(result)
                                milvus.create_index(index_type, nlist)
                                result = milvus.describe_index()
                                logger.info(result)
                                # preload index
                                milvus.preload_table()
                                logger.info("Start warm up query")
                                res = self.do_query(milvus, table_name, [1], [1], 1, 1)
                                logger.info("End warm up query")
                                # Run query test
                                for nprobe in nprobes:
                                    logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                    res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                                    headers = ["Nprobe/Top-k"]
                                    headers.extend([str(top_k) for top_k in top_ks])
                                    utils.print_table(headers, nqs, res)
                        utils.remove_container(container)

        elif run_type == "accuracy":
            """
            Example accuracy suite entry:
            {
                "dataset": "random_50m_1024_512",
                "index.index_types": ["flat", "ivf_flat", "ivf_sq8"],
                "index.nlists": [16384],
                "nprobes": [1, 32, 128],
                "nqs": [100],
                "top_ks": [1, 64],
                "server.use_blas_threshold": 1100,
                "server.cpu_cache_capacity": 256
            }
            """
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None

                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    # sift_acc selects groundtruth-based recall instead of
                    # flat-vs-index comparison.
                    sift_acc = False
                    if "sift_acc" in param:
                        sift_acc = param["sift_acc"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")
                    volume_name = param["db_path_prefix"]
                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(table_name)
                    # Check has table or not
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue

                    # parse index info
                    index_types = param["index.index_types"]
                    nlists = param["index.nlists"]
                    # parse top-k, nq, nprobe
                    top_ks, nqs, nprobes = parser.search_params_parser(param)

                    if sift_acc is True:
                        # preload groundtruth data
                        true_ids_all = self.get_groundtruth_ids(table_size)

                    acc_dict = {}
                    for index_type in index_types:
                        for nlist in nlists:
                            result = milvus.describe_index()
                            logger.info(result)
                            milvus.create_index(index_type, nlist)
                            # preload index
                            milvus.preload_table()
                            # Run query test
                            for nprobe in nprobes:
                                logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                for top_k in top_ks:
                                    for nq in nqs:
                                        result_ids = []
                                        id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                            (table_name, index_type, nlist, metric_type, nprobe, top_k, nq)
                                        if sift_acc is False:
                                            # Store result ids to a file, then compare
                                            # against the "flat" run's file.
                                            self.do_query_acc(milvus, table_name, top_k, nq, nprobe, id_prefix)
                                            if index_type != "flat":
                                                # Compute accuracy
                                                base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
                                                    (table_name, nlist, metric_type, nprobe, top_k, nq)
                                                avg_acc = self.compute_accuracy(base_name, id_prefix)
                                                logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc))
                                        else:
                                            # Recall against the sift1b groundtruth.
                                            result_ids = self.do_query_ids(milvus, table_name, top_k, nq, nprobe)
                                            acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
                                            logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value))
                    # # print accuracy table
                    # headers = [table_name]
                    # headers.extend([str(top_k) for top_k in top_ks])
                    # utils.print_table(headers, nqs, res)

                    # remove container, and run next definition
                    logger.info("remove container, and run next definition")
                    utils.remove_container(container)

        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                container = None
                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    volume_name = param["db_path_prefix"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                    # set default test time
                    if "during_time" not in param:
                        during_time = 100  # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # set default query process num
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])

                    for k, v in param.items():
                        if k.startswith("server."):
                            utils.modify_config(k, v, type="server")

                    container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
                    time.sleep(2)
                    milvus = MilvusClient(table_name)
                    # Check has table or not
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue

                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]
                    # Hammer the server with random searches until during_time
                    # elapses; periodically insert and rebuild the index.
                    while time.time() < start_time + during_time:
                        processes = []
                        # do query
                        # for i in range(query_process_num):
                        #     milvus_instance = MilvusClient(table_name)
                        #     top_k = random.choice([x for x in range(1, 100)])
                        #     nq = random.choice([x for x in range(1, 100)])
                        #     nprobe = random.choice([x for x in range(1, 1000)])
                        #     # logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                        #     p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], [nprobe], run_count, ))
                        #     processes.append(p)
                        #     p.start()
                        #     time.sleep(0.1)
                        # for p in processes:
                        #     p.join()
                        milvus_instance = MilvusClient(table_name)
                        top_ks = random.sample([x for x in range(1, 100)], 3)
                        nqs = random.sample([x for x in range(1, 1000)], 3)
                        nprobe = random.choice([x for x in range(1, 500)])
                        # NOTE(review): queries use `milvus` while inserts use the
                        # freshly created `milvus_instance` — both point at the
                        # same table; confirm this is intentional.
                        res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                        if int(time.time() - start_time) % 120 == 0:
                            status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                            if not status.OK():
                                logger.error(status)
                            # status = milvus_instance.drop_index()
                            # if not status.OK():
                            #     logger.error(status)
                            # index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
                            result = milvus.describe_index()
                            logger.info(result)
                            milvus_instance.create_index("ivf_sq8", 16384)
                    utils.remove_container(container)

        else:
            logger.warning("Run type: %s not supported" % run_type)
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from client import MilvusClient
|
||||
import utils
|
||||
import parser
|
||||
from runner import Runner
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.local_runner")
|
||||
|
||||
|
||||
class LocalRunner(Runner):
    """Benchmark runner targeting an already-running Milvus server at ip:port."""
    def __init__(self, ip, port):
        # ip/port: address of the externally managed Milvus server.
        super(LocalRunner, self).__init__()
        self.ip = ip
        self.port = port

    def run(self, definition, run_type=None):
        """Execute one suite ``definition`` for the given ``run_type``.

        :param definition: maps op type ("insert"/"query") to a dict holding
            "run_count" and "params" (a list of per-run parameter dicts)
        :param run_type: "performance" or "stability" (no accuracy mode here)

        NOTE(review): block nesting was reconstructed from a formatting-mangled
        source — verify against upstream before relying on exact control flow.
        """
        if run_type == "performance":
            for op_type, op_value in definition.items():
                run_count = op_value["run_count"]
                run_params = op_value["params"]

                if op_type == "insert":
                    for index, param in enumerate(run_params):
                        table_name = param["table_name"]
                        # random_1m_100_512
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # Check has table or not; recreate so the benchmark
                        # starts from an empty table.
                        if milvus.exists_table():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_table(table_name, dimension, index_file_size, metric_type)
                        res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
                        logger.info(res)

                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["dataset"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # parse index info
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # parse top-k, nq, nprobe
                        top_ks, nqs, nprobes = parser.search_params_parser(param)

                        for index_type in index_types:
                            for nlist in nlists:
                                milvus.create_index(index_type, nlist)
                                # preload index
                                milvus.preload_table()
                                # Run query test
                                for nprobe in nprobes:
                                    logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                    res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                                    headers = [param["dataset"]]
                                    headers.extend([str(top_k) for top_k in top_ks])
                                    utils.print_table(headers, nqs, res)

        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                # Size of the batch inserted on every loop iteration.
                nq = 10000

                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                    # set default test time
                    if "during_time" not in param:
                        during_time = 100  # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # set default query process num
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])
                    milvus = MilvusClient(table_name)
                    # Check has table or not
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue

                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
                    # Hammer the server with random searches until during_time
                    # elapses; insert every iteration, occasionally re-index.
                    while time.time() < start_time + during_time:
                        processes = []
                        # # do query
                        # for i in range(query_process_num):
                        #     milvus_instance = MilvusClient(table_name)
                        #     top_k = random.choice([x for x in range(1, 100)])
                        #     nq = random.choice([x for x in range(1, 1000)])
                        #     nprobe = random.choice([x for x in range(1, 500)])
                        #     logger.info(nprobe)
                        #     p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], 64, run_count, ))
                        #     processes.append(p)
                        #     p.start()
                        #     time.sleep(0.1)
                        # for p in processes:
                        #     p.join()
                        milvus_instance = MilvusClient(table_name)
                        top_ks = random.sample([x for x in range(1, 100)], 4)
                        nqs = random.sample([x for x in range(1, 1000)], 3)
                        nprobe = random.choice([x for x in range(1, 500)])
                        res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                        # milvus_instance = MilvusClient(table_name)
                        status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                        if not status.OK():
                            logger.error(status.message)
                        # NOTE(review): float modulo — (time.time() - start_time)
                        # is a float, so `% 300 == 0` is almost never exactly
                        # true; the re-index branch below will rarely fire.
                        # Compare with docker_runner, which wraps this in int().
                        if (time.time() - start_time) % 300 == 0:
                            status = milvus_instance.drop_index()
                            if not status.OK():
                                logger.error(status.message)
                            index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
                            status = milvus_instance.create_index(index_type, 16384)
                            if not status.OK():
                                logger.error(status.message)
|
|
@ -0,0 +1,131 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import pdb
|
||||
import argparse
|
||||
import logging
|
||||
import utils
|
||||
from yaml import load, dump
|
||||
from logging import handlers
|
||||
from parser import operations_parser
|
||||
from local_runner import LocalRunner
|
||||
from docker_runner import DockerRunner
|
||||
|
||||
# Docker image used when the caller does not pass --image explicitly.
DEFAULT_IMAGE = "milvusdb/milvus:latest"
# Directory receiving the rotating benchmark log files.
LOG_FOLDER = "benchmark_logs"
logger = logging.getLogger("milvus_benchmark")

formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
# Create the log folder up front; os.makedirs replaces the original shell
# call ("mkdir -p") with a portable stdlib equivalent.
if not os.path.exists(LOG_FOLDER):
    os.makedirs(LOG_FOLDER)
# Rotate daily (when="D", interval=1) and keep the last 10 files.
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'milvus_benchmark'), "D", 1, 10)
fileTimeHandler.suffix = "%Y%m%d.log"
# Bug fix: the formatter was attached twice in the original; once suffices.
fileTimeHandler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG)
logger.addHandler(fileTimeHandler)
|
||||
|
||||
|
||||
def positive_int(s):
    """argparse ``type`` callable: parse *s* as a strictly positive integer.

    :param s: value supplied on the command line
    :return: the parsed integer (>= 1)
    :raises argparse.ArgumentTypeError: when *s* is not an integer >= 1
    """
    try:
        value = int(s)
    except ValueError:
        value = None
    if not value or value < 1:
        raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
    return value
|
||||
|
||||
|
||||
# # link random_data if not exists
|
||||
# def init_env():
|
||||
# if not os.path.islink(BINARY_DATA_FOLDER):
|
||||
# try:
|
||||
# os.symlink(SRC_BINARY_DATA_FOLDER, BINARY_DATA_FOLDER)
|
||||
# except Exception as e:
|
||||
# logger.error("Create link failed: %s" % str(e))
|
||||
# sys.exit()
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, load the benchmark suites and dispatch to a runner.

    Docker mode (--image) and local mode (--local) are mutually exclusive;
    each iterates the parsed suite operations and delegates to the matching
    runner's ``run`` method.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--image',
        help='use the given image')
    parser.add_argument(
        '--local',
        action='store_true',
        help='use local milvus server')
    parser.add_argument(
        "--run-count",
        default=1,
        type=positive_int,
        help="run each db operation times")
    # performance / stability / accuracy test
    parser.add_argument(
        "--run-type",
        default="performance",
        help="run type, default performance")
    parser.add_argument(
        '--suites',
        metavar='FILE',
        help='load test suites from FILE',
        default='suites.yaml')
    parser.add_argument(
        '--ip',
        help='server ip param for local mode',
        default='127.0.0.1')
    parser.add_argument(
        '--port',
        help='server port param for local mode',
        default='19530')

    args = parser.parse_args()

    operations = None
    # Get all benchmark test suites
    if args.suites:
        # The with-statement closes the file; the original's explicit
        # f.close() inside the block was redundant and has been removed.
        with open(args.suites) as f:
            suites_dict = load(f)
        # With definition order
        operations = operations_parser(suites_dict, run_type=args.run_type)

    # init_env()
    run_params = {"run_count": args.run_count}

    if args.image:
        # for docker mode
        if args.local:
            logger.error("Local mode and docker mode are incompatible arguments")
            sys.exit(-1)
        # Docker pull image
        if not utils.pull_image(args.image):
            # Bug fix: the original interpolated the undefined name `image`
            # here, raising NameError instead of the intended message.
            raise Exception('Image %s pull failed' % args.image)

        # TODO: Check milvus server port is available
        logger.info("Init: remove all containers created with image: %s" % args.image)
        utils.remove_all_containers(args.image)
        runner = DockerRunner(args.image)
        for operation_type in operations:
            logger.info("Start run test, test type: %s" % operation_type)
            run_params["params"] = operations[operation_type]
            runner.run({operation_type: run_params}, run_type=args.run_type)
            logger.info("Run params: %s" % str(run_params))

    if args.local:
        # for local mode
        ip = args.ip
        port = args.port

        runner = LocalRunner(ip, port)
        for operation_type in operations:
            logger.info("Start run local mode test, test type: %s" % operation_type)
            run_params["params"] = operations[operation_type]
            runner.run({operation_type: run_params}, run_type=args.run_type)
            logger.info("Run params: %s" % str(run_params))


if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,10 @@
|
|||
from __future__ import absolute_import
|
||||
import pdb
|
||||
import time
|
||||
|
||||
class Base(object):
    """Placeholder base type for benchmark operation models."""
    pass
|
||||
|
||||
|
||||
class Insert(Base):
    """Placeholder model for an insert benchmark operation."""
    pass
|
|
@ -0,0 +1,66 @@
|
|||
import pdb
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.parser")
|
||||
|
||||
|
||||
def operations_parser(operations, run_type="performance"):
    """Return the operation definitions registered under *run_type*.

    :param operations: suite dict keyed by run type
    :param run_type: section to extract, defaults to "performance"
    :raises KeyError: when the suite has no such section
    """
    return operations[run_type]
|
||||
|
||||
|
||||
def table_parser(table_name):
    """Decode a benchmark table name into its component fields.

    Names look like ``<type>_<size><unit>_<file_size>_<dim>_<metric>``,
    e.g. ``random_1m_1024_512_l2``. Unit "m" means millions, "b" billions;
    any other unit leaves the size as the raw string.

    :return: (data_type, table_size, index_file_size, dimension, metric_type)
    """
    fields = table_name.split("_")
    data_type = fields[0]
    size_token = fields[1]
    unit = size_token[-1]
    table_size = size_token[0:-1]
    if unit == "m":
        table_size = int(table_size) * 1000000
    elif unit == "b":
        table_size = int(table_size) * 1000000000
    index_file_size = int(fields[2])
    dimension = int(fields[3])
    metric_type = str(fields[4])
    return (data_type, table_size, index_file_size, dimension, metric_type)
|
||||
|
||||
|
||||
def search_params_parser(param):
    """Extract ``(top_ks, nqs, nprobes)`` lists from a suite *param* dict.

    Missing keys fall back to defaults (top_ks=[10], nqs=[10], nprobes=[1]).
    Scalar ints are wrapped into one-element lists; any other type is logged
    as a warning and passed through unchanged, matching the original
    best-effort behavior.
    """
    def _listify(value, warn_fmt):
        # Normalize a scalar-or-list setting; warn on anything else.
        if isinstance(value, int):
            return [value]
        if isinstance(value, list):
            return list(value)
        logger.warning(warn_fmt % str(value))
        return value

    # parse top-k, set default value if top-k not in param
    top_ks = _listify(param["top_ks"], "Invalid format top-ks: %s") if "top_ks" in param else [10]
    # parse nqs, set default value if nq not in param
    nqs = _listify(param["nqs"], "Invalid format nqs: %s") if "nqs" in param else [10]
    # parse nprobes
    nprobes = _listify(param["nprobes"], "Invalid format nprobes: %s") if "nprobes" in param else [1]

    return top_ks, nqs, nprobes
|
|
@ -0,0 +1,10 @@
|
|||
# from tablereport import Table
|
||||
# from tablereport.shortcut import write_to_excel
|
||||
|
||||
# RESULT_FOLDER = "results"
|
||||
|
||||
# def create_table(headers, bodys, table_name):
|
||||
# table = Table(header=[headers],
|
||||
# body=[bodys])
|
||||
|
||||
# write_to_excel('%s/%s.xlsx' % (RESULT_FOLDER, table_name), table)
|
|
@ -0,0 +1,6 @@
|
|||
numpy==1.16.3
|
||||
pymilvus>=0.1.18
|
||||
pyyaml==3.12
|
||||
docker==4.0.2
|
||||
tableprint==0.8.0
|
||||
ansicolors==1.1.8
|
|
@ -0,0 +1,219 @@
|
|||
import os
|
||||
import logging
|
||||
import pdb
|
||||
import time
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from client import MilvusClient
|
||||
import utils
|
||||
import parser
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.runner")
|
||||
|
||||
SERVER_HOST_DEFAULT = "127.0.0.1"
|
||||
SERVER_PORT_DEFAULT = 19530
|
||||
VECTORS_PER_FILE = 1000000
|
||||
SIFT_VECTORS_PER_FILE = 100000
|
||||
MAX_NQ = 10001
|
||||
FILE_PREFIX = "binary_"
|
||||
|
||||
RANDOM_SRC_BINARY_DATA_DIR = '/tmp/random/binary_data'
|
||||
SIFT_SRC_DATA_DIR = '/tmp/sift1b/query'
|
||||
SIFT_SRC_BINARY_DATA_DIR = '/tmp/sift1b/binary_data'
|
||||
SIFT_SRC_GROUNDTRUTH_DATA_DIR = '/tmp/sift1b/groundtruth'
|
||||
|
||||
WARM_TOP_K = 1
|
||||
WARM_NQ = 1
|
||||
DEFAULT_DIM = 512
|
||||
|
||||
GROUNDTRUTH_MAP = {
|
||||
"1000000": "idx_1M.ivecs",
|
||||
"2000000": "idx_2M.ivecs",
|
||||
"5000000": "idx_5M.ivecs",
|
||||
"10000000": "idx_10M.ivecs",
|
||||
"20000000": "idx_20M.ivecs",
|
||||
"50000000": "idx_50M.ivecs",
|
||||
"100000000": "idx_100M.ivecs",
|
||||
"200000000": "idx_200M.ivecs",
|
||||
"500000000": "idx_500M.ivecs",
|
||||
"1000000000": "idx_1000M.ivecs",
|
||||
}
|
||||
|
||||
|
||||
def gen_file_name(idx, table_dimension, data_type):
    """Return the path of the idx-th binary .npy file for a dataset family.

    Only "random" and "sift" get their source-directory prefix; any other
    data_type yields the bare file name, matching the original behavior.
    """
    fname = "%s%sd_%05d.npy" % (FILE_PREFIX, table_dimension, idx)
    if data_type == "random":
        return RANDOM_SRC_BINARY_DATA_DIR + '/' + fname
    if data_type == "sift":
        return SIFT_SRC_BINARY_DATA_DIR + '/' + fname
    return fname
|
||||
|
||||
|
||||
def get_vectors_from_binary(nq, dimension, data_type):
    """Load the first *nq* query vectors for a dataset family.

    :param nq: number of vectors to load; must not exceed MAX_NQ
    :param dimension: vector dimension (selects the file name for "random")
    :param data_type: "random" or "sift"
    :return: list of nq vectors (lists of floats)
    :raises Exception: when nq is too large or data_type is unsupported
    """
    # use the first file, nq should be less than VECTORS_PER_FILE
    if nq > MAX_NQ:
        raise Exception("Over size nq")
    if data_type == "random":
        file_name = gen_file_name(0, dimension, data_type)
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR+'/'+'query.npy'
    else:
        # Bug fix: the original fell through with file_name unbound for any
        # other data_type, raising an opaque NameError at np.load below.
        raise Exception("Unsupported data type: %s" % data_type)
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors
|
||||
|
||||
|
||||
class Runner(object):
    """Shared benchmark primitives: bulk insert, timed query, accuracy helpers."""
    def __init__(self):
        pass

    def do_insert(self, milvus, table_name, data_type, dimension, size, ni):
        '''
        Insert `size` vectors from the pre-generated .npy files in batches.

        @params:
            milvus: server connect instance
            table_name: target table (unused here; the client is pre-bound)
            data_type: "random" or "sift" — selects source files and per-file size
            dimension: table dimension
            size: row count of vectors to be insert; must be a multiple of
                the per-file vector count
            ni: row count of vectors to be insert each time; must not exceed
                the per-file vector count
        @return:
            dict with:
            total_time: total time for all insert operation
            qps: vectors added per second
            ni_time: average insert operation time

        NOTE(review): if data_type is neither "random" nor "sift",
        vectors_per_file is unbound and the modulo below raises NameError —
        confirm callers only pass those two values.
        '''
        bi_res = {}
        total_time = 0.0
        qps = 0.0
        ni_time = 0.0
        if data_type == "random":
            vectors_per_file = VECTORS_PER_FILE
        elif data_type == "sift":
            vectors_per_file = SIFT_VECTORS_PER_FILE
        if size % vectors_per_file or ni > vectors_per_file:
            raise Exception("Not invalid table size or ni")
        file_num = size // vectors_per_file
        for i in range(file_num):
            file_name = gen_file_name(i, dimension, data_type)
            logger.info("Load npy file: %s start" % file_name)
            data = np.load(file_name)
            logger.info("Load npy file: %s end" % file_name)
            loops = vectors_per_file // ni
            for j in range(loops):
                vectors = data[j*ni:(j+1)*ni].tolist()
                ni_start_time = time.time()
                # start insert vectors
                # Assign sequential ids so every vector's id equals its global row index.
                start_id = i * vectors_per_file + j * ni
                end_id = start_id + len(vectors)
                logger.info("Start id: %s, end id: %s" % (start_id, end_id))
                ids = [k for k in range(start_id, end_id)]
                status, ids = milvus.insert(vectors, ids=ids)
                ni_end_time = time.time()
                total_time = total_time + ni_end_time - ni_start_time

        qps = round(size / total_time, 2)
        ni_time = round(total_time / (loops * file_num), 2)
        bi_res["total_time"] = round(total_time, 2)
        bi_res["qps"] = qps
        bi_res["ni_time"] = ni_time
        return bi_res

    def do_query(self, milvus, table_name, top_ks, nqs, nprobe, run_count):
        """Measure average search latency for every (nq, top_k) combination.

        Query vectors come from the table's binary source files (first MAX_NQ
        rows). Returns ``bi_res[nq_index][top_k_index]`` of average seconds.
        """
        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)

        bi_res = []
        for index, nq in enumerate(nqs):
            tmp_res = []
            for top_k in top_ks:
                avg_query_time = 0.0
                total_query_time = 0.0
                vectors = base_query_vectors[0:nq]
                logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
                for i in range(run_count):
                    logger.info("Start run query, run %d of %s" % (i+1, run_count))
                    start_time = time.time()
                    status, query_res = milvus.query(vectors, top_k, nprobe)
                    total_query_time = total_query_time + (time.time() - start_time)
                    if status.code:
                        # Failed runs still count toward the average.
                        logger.error("Query failed with message: %s" % status.message)
                avg_query_time = round(total_query_time / run_count, 2)
                logger.info("Avarage query time: %.2f" % avg_query_time)
                tmp_res.append(avg_query_time)
            bi_res.append(tmp_res)
        return bi_res

    def do_query_ids(self, milvus, table_name, top_k, nq, nprobe):
        """Run one search and return the result ids as a list per query vector.

        :raises Exception: when the server reports a non-OK status
        """
        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
        status, query_res = milvus.query(vectors, top_k, nprobe)
        if not status.OK():
            msg = "Query failed with message: %s" % status.message
            raise Exception(msg)
        result_ids = []
        for result in query_res:
            tmp = []
            for item in result:
                tmp.append(item.id)
            result_ids.append(tmp)
        return result_ids

    def do_query_acc(self, milvus, table_name, top_k, nq, nprobe, id_store_name):
        """Run one search and persist the result ids to *id_store_name*.

        Each line of the output file holds the tab-separated ids for one
        query vector; an existing file is overwritten.
        :raises Exception: when the server reports a non-OK status
        """
        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
        status, query_res = milvus.query(vectors, top_k, nprobe)
        if not status.OK():
            msg = "Query failed with message: %s" % status.message
            raise Exception(msg)
        # if file existed, cover it
        if os.path.isfile(id_store_name):
            os.remove(id_store_name)
        with open(id_store_name, 'a+') as fd:
            for nq_item in query_res:
                for item in nq_item:
                    fd.write(str(item.id)+'\t')
                fd.write('\n')

    # compute and print accuracy
    def compute_accuracy(self, flat_file_name, index_file_name):
        """Compare two id files written by do_query_acc and return the recall.

        :param flat_file_name: ids produced by the exact ("flat") index
        :param index_file_name: ids produced by the approximate index
        :raises Exception: when the files have a different number of lines
        """
        flat_id_list = []; index_id_list = []
        logger.info("Loading flat id file: %s" % flat_file_name)
        with open(flat_file_name, 'r') as flat_id_fd:
            for line in flat_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                flat_id_list.append(tmp_list)
        logger.info("Loading index id file: %s" % index_file_name)
        with open(index_file_name) as index_id_fd:
            for line in index_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                index_id_list.append(tmp_list)
        if len(flat_id_list) != len(index_id_list):
            raise Exception("Flat index result length: <flat: %s, index: %s> not match, Acc compute exiting ..." % (len(flat_id_list), len(index_id_list)))
        # get the accuracy
        return self.get_recall_value(flat_id_list, index_id_list)

    def get_recall_value(self, flat_id_list, index_id_list):
        """
        Use the intersection length

        Average, over all queries, of |approx ∩ exact| / |approx|,
        rounded to 3 decimal places.
        """
        sum_radio = 0.0
        for index, item in enumerate(index_id_list):
            tmp = set(item).intersection(set(flat_id_list[index]))
            sum_radio = sum_radio + len(tmp) / len(item)
        return round(sum_radio / len(index_id_list), 3)

    """
    Implementation based on:
        https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
    """
    def get_groundtruth_ids(self, table_size):
        """Load sift1b groundtruth neighbor ids for the given subset size.

        The .ivecs format stores, per row, a leading int32 count followed by
        that many int32 ids; the reshape below strips the count column.
        :return: 2-D numpy int32 array of neighbor ids, one row per query
        """
        fname = GROUNDTRUTH_MAP[str(table_size)]
        fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
        a = np.fromfile(fname, dtype='int32')
        d = a[0]
        true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
        return true_ids
|
|
@ -0,0 +1,38 @@
|
|||
# data sets
|
||||
datasets:
|
||||
hf5:
|
||||
gist-960,sift-128
|
||||
npy:
|
||||
50000000-512, 100000000-512
|
||||
|
||||
operations:
|
||||
# interface: search_vectors
|
||||
query:
|
||||
# dataset: table name you have already created
|
||||
# keys starting with "server." require a server reconfig and restart, e.g. nprobe/nlist/use_blas_threshold/..
|
||||
[
|
||||
# debug
|
||||
# {"dataset": "ip_ivfsq8_1000", "top_ks": [16], "nqs": [1], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
|
||||
|
||||
{"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
|
||||
{"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
|
||||
{"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
|
||||
{"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
|
||||
{"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
|
||||
# {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], "nqs": [1, 10, 100, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
|
||||
]
|
||||
|
||||
# interface: add_vectors
|
||||
insert:
|
||||
# index_type: flat/ivf_flat/ivf_sq8
|
||||
[
|
||||
# debug
|
||||
|
||||
|
||||
{"table_name": "ip_ivf_flat_20m_1024", "table.index_type": "ivf_flat", "server.index_building_threshold": 1024, "table.size": 20000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110},
|
||||
{"table_name": "ip_ivf_sq8_50m_1024", "table.index_type": "ivf_sq8", "server.index_building_threshold": 1024, "table.size": 50000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110},
|
||||
]
|
||||
|
||||
# TODO: interface: build_index
|
||||
build: []
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
|
||||
accuracy:
|
||||
# interface: search_vectors
|
||||
query:
|
||||
[
|
||||
{
|
||||
"dataset": "random_20m_1024_512_ip",
|
||||
# index info
|
||||
"index.index_types": ["flat", "ivf_sq8"],
|
||||
"index.nlists": [16384],
|
||||
"index.metric_types": ["ip"],
|
||||
"nprobes": [1, 16, 64],
|
||||
"top_ks": [64],
|
||||
"nqs": [100],
|
||||
"server.cpu_cache_capacity": 100,
|
||||
"server.resources": ["cpu", "gpu0"],
|
||||
"db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip",
|
||||
},
|
||||
# {
|
||||
# "dataset": "sift_50m_1024_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "index.metric_types": ["l2"],
|
||||
# "nprobes": [1, 16, 64],
|
||||
# "top_ks": [64],
|
||||
# "nqs": [100],
|
||||
# "server.cpu_cache_capacity": 160,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2",
|
||||
# "sift_acc": true
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_50m_1024_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "index.metric_types": ["l2"],
|
||||
# "nprobes": [1, 16, 64],
|
||||
# "top_ks": [64],
|
||||
# "nqs": [100],
|
||||
# "server.cpu_cache_capacity": 160,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2_sq8",
|
||||
# "sift_acc": true
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "index.metric_types": ["l2"],
|
||||
# "nprobes": [1, 16, 64, 128],
|
||||
# "top_ks": [64],
|
||||
# "nqs": [100],
|
||||
# "server.cpu_cache_capacity": 200,
|
||||
# "server.resources": ["cpu"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
|
||||
# "sift_acc": true
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "index.metric_types": ["l2"],
|
||||
# "nprobes": [1, 16, 64, 128],
|
||||
# "top_ks": [64],
|
||||
# "nqs": [100],
|
||||
# "server.cpu_cache_capacity": 200,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
|
||||
# "sift_acc": true
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "index.metric_types": ["l2"],
|
||||
# "nprobes": [1, 16, 64, 128],
|
||||
# "top_ks": [64],
|
||||
# "nqs": [100],
|
||||
# "server.cpu_cache_capacity": 200,
|
||||
# "server.resources": ["cpu", "gpu0", "gpu1"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
|
||||
# "sift_acc": true
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1m_1024_128_l2",
|
||||
# "index.index_types": ["flat", "ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1, 32, 128, 256, 512],
|
||||
# "nqs": 10,
|
||||
# "top_ks": 10,
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_10m_1024_128_l2",
|
||||
# "index.index_types": ["flat", "ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1, 32, 128, 256, 512],
|
||||
# "nqs": 10,
|
||||
# "top_ks": 10,
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 32,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_50m_1024_128_l2",
|
||||
# "index.index_types": ["flat", "ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1, 32, 128, 256, 512],
|
||||
# "nqs": 10,
|
||||
# "top_ks": 10,
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 64,
|
||||
# }
|
||||
|
||||
|
||||
]
|
|
@ -0,0 +1,258 @@
|
|||
performance:
|
||||
|
||||
# interface: add_vectors
|
||||
insert:
|
||||
# index_type: flat/ivf_flat/ivf_sq8/mix_nsg
|
||||
[
|
||||
# debug
|
||||
# data_type / data_size / index_file_size / dimension
|
||||
# data_type: random / ann_sift
|
||||
# data_size: 10m / 1b
|
||||
# {
|
||||
# "table_name": "random_50m_1024_512_ip",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# # "server.resources": ["gpu0", "gpu1"],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
# {
|
||||
# "table_name": "random_5m_1024_512_ip",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# "server.resources": ["gpu0", "gpu1"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/random_5m_1024_512_ip"
|
||||
# },
|
||||
# {
|
||||
# "table_name": "sift_1m_50_128_l2",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# # "server.cpu_cache_capacity": 16,
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
# {
|
||||
# "table_name": "sift_1m_256_128_l2",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# # "server.cpu_cache_capacity": 16,
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# }
|
||||
# {
|
||||
# "table_name": "sift_50m_1024_128_l2",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# # "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "table_name": "sift_100m_1024_128_l2",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# },
|
||||
# {
|
||||
# "table_name": "sift_1b_2048_128_l2",
|
||||
# "ni_per": 100000,
|
||||
# "processes": 5, # multiprocessing
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# }
|
||||
]
|
||||
|
||||
# interface: search_vectors
|
||||
query:
|
||||
# dataset: table name you have already created
|
||||
# key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity ..
|
||||
[
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [8, 32],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 200,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [8, 32],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 200,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [8, 32],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 200,
|
||||
# "server.resources": ["cpu"],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
{
|
||||
"dataset": "random_50m_1024_512_ip",
|
||||
"index.index_types": ["ivf_sq8h"],
|
||||
"index.nlists": [16384],
|
||||
"nprobes": [8],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
"top_ks": [512],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
"nqs": [500],
|
||||
"server.use_blas_threshold": 1100,
|
||||
"server.cpu_cache_capacity": 150,
|
||||
"server.gpu_cache_capacity": 6,
|
||||
"server.resources": ["cpu", "gpu0", "gpu1"],
|
||||
"db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip"
|
||||
},
|
||||
# {
|
||||
# "dataset": "random_50m_1024_512_ip",
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [8, 32],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 150,
|
||||
# "server.resources": ["cpu", "gpu0", "gpu1"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip_sq8"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_20m_1024_512_ip",
|
||||
# "index.index_types": ["flat"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [50],
|
||||
# "top_ks": [64],
|
||||
# "nqs": [10],
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 100,
|
||||
# "server.resources": ["cpu", "gpu0", "gpu1"],
|
||||
# "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_100m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [8, 32],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 250,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_100m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [8, 32],
|
||||
# "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
|
||||
# "nqs": [1, 10, 100, 500, 1000],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 250,
|
||||
# "server.resources": ["cpu"],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# # "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 64
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_500m_1024_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 8, 16, 64, 256, 512, 1000],
|
||||
# "nqs": [1, 100, 500, 800, 1000, 1500],
|
||||
# # "top_ks": [256],
|
||||
# # "nqs": [800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# # "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 120,
|
||||
# "server.resources": ["gpu0", "gpu1"],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8h"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# # "top_ks": [1],
|
||||
# # "nqs": [1],
|
||||
# "top_ks": [256],
|
||||
# "nqs": [800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 110,
|
||||
# "server.resources": ["cpu", "gpu0"],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_50m_1024_512_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# # "top_ks": [256],
|
||||
# # "nqs": [800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 128
|
||||
# },
|
||||
# [
|
||||
# {
|
||||
# "dataset": "sift_1m_50_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1],
|
||||
# "nqs": [1],
|
||||
# "db_path_prefix": "/test/milvus/db_data"
|
||||
# # "processes": 1, # multiprocessing
|
||||
# # "server.use_blas_threshold": 1100,
|
||||
# # "server.cpu_cache_capacity": 256
|
||||
# }
|
||||
]
|
|
@ -0,0 +1,17 @@
|
|||
|
||||
stability:
|
||||
# interface: search_vectors / add_vectors mix operation
|
||||
query:
|
||||
[
|
||||
{
|
||||
"dataset": "random_20m_1024_512_ip",
|
||||
# "nqs": [1, 10, 100, 1000, 10000],
|
||||
# "pds": [0.1, 0.44, 0.44, 0.02],
|
||||
"query_process_num": 10,
|
||||
# each 10s, do an insertion
|
||||
# "insert_interval": 1,
|
||||
# minutes
|
||||
"during_time": 360,
|
||||
"server.cpu_cache_capacity": 100
|
||||
},
|
||||
]
|
|
@ -0,0 +1,171 @@
|
|||
#"server.resources": ["gpu0", "gpu1"]
|
||||
|
||||
performance:
|
||||
# interface: search_vectors
|
||||
query:
|
||||
# dataset: table name you have already created
|
||||
# key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity ..
|
||||
[
|
||||
# debug
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# # "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# # "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 16,
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_10m_1024_512_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 64
|
||||
# },
|
||||
# {
|
||||
# "dataset": "sift_50m_1024_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1, 32, 128],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# # "top_ks": [256],
|
||||
# # "nqs": [800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 310,
|
||||
# "server.resources": ["gpu0", "gpu1"]
|
||||
# },
|
||||
{
|
||||
"dataset": "sift_1m_1024_128_l2",
|
||||
# index info
|
||||
"index.index_types": ["ivf_sq8"],
|
||||
"index.nlists": [16384],
|
||||
"nprobes": [32],
|
||||
"top_ks": [10],
|
||||
"nqs": [100],
|
||||
# "top_ks": [256],
|
||||
# "nqs": [800],
|
||||
"processes": 1, # multiprocessing
|
||||
"server.use_blas_threshold": 1100,
|
||||
"server.cpu_cache_capacity": 310,
|
||||
"server.resources": ["cpu"]
|
||||
},
|
||||
{
|
||||
"dataset": "sift_1m_1024_128_l2",
|
||||
# index info
|
||||
"index.index_types": ["ivf_sq8"],
|
||||
"index.nlists": [16384],
|
||||
"nprobes": [32],
|
||||
"top_ks": [10],
|
||||
"nqs": [100],
|
||||
# "top_ks": [256],
|
||||
# "nqs": [800],
|
||||
"processes": 1, # multiprocessing
|
||||
"server.use_blas_threshold": 1100,
|
||||
"server.cpu_cache_capacity": 310,
|
||||
"server.resources": ["gpu0"]
|
||||
},
|
||||
{
|
||||
"dataset": "sift_1m_1024_128_l2",
|
||||
# index info
|
||||
"index.index_types": ["ivf_sq8"],
|
||||
"index.nlists": [16384],
|
||||
"nprobes": [32],
|
||||
"top_ks": [10],
|
||||
"nqs": [100],
|
||||
# "top_ks": [256],
|
||||
# "nqs": [800],
|
||||
"processes": 1, # multiprocessing
|
||||
"server.use_blas_threshold": 1100,
|
||||
"server.cpu_cache_capacity": 310,
|
||||
"server.resources": ["gpu0", "gpu1"]
|
||||
},
|
||||
# {
|
||||
# "dataset": "sift_1b_2048_128_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# # "top_ks": [256],
|
||||
# # "nqs": [800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 310
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_50m_1024_512_l2",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# # "top_ks": [256],
|
||||
# # "nqs": [800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 128,
|
||||
# "server.resources": ["gpu0", "gpu1"]
|
||||
# },
|
||||
# {
|
||||
# "dataset": "random_100m_1024_512_ip",
|
||||
# # index info
|
||||
# "index.index_types": ["ivf_sq8"],
|
||||
# "index.nlists": [16384],
|
||||
# "nprobes": [1],
|
||||
# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
|
||||
# "nqs": [1, 10, 100, 500, 800],
|
||||
# "processes": 1, # multiprocessing
|
||||
# "server.use_blas_threshold": 1100,
|
||||
# "server.cpu_cache_capacity": 256
|
||||
# },
|
||||
]
|
|
@ -0,0 +1,194 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function
|
||||
|
||||
__true_print = print # noqa
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pdb
|
||||
import time
|
||||
import datetime
|
||||
import argparse
|
||||
import threading
|
||||
import logging
|
||||
import docker
|
||||
import multiprocessing
|
||||
import numpy
|
||||
# import psutil
|
||||
from yaml import load, dump
|
||||
import tableprint as tp
|
||||
|
||||
logger = logging.getLogger("milvus_benchmark.utils")
|
||||
|
||||
MULTI_DB_SLAVE_PATH = "/opt/milvus/data2;/opt/milvus/data3"
|
||||
|
||||
|
||||
def get_current_time():
    """Return the current local time formatted as 'YYYY-mm-dd HH:MM:SS'."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
|
||||
def print_table(headers, columns, data):
    """Render a table via tableprint: row i is [columns[i], *data[i]]."""
    rows = []
    for row_idx, row_label in enumerate(columns):
        rows.append([row_label] + list(data[row_idx]))
    tp.table(rows, headers)
|
||||
|
||||
|
||||
def modify_config(k, v, type=None, file_path="conf/server_config.yaml", db_slave=None):
    """Patch a single value in the server YAML config at *file_path*.

    k: key matched by substring; supported: use_blas_threshold,
       cpu_cache_capacity, gpu_cache_capacity, resource_pool.
    v: new value (cast to int for the numeric keys).
    type: unused; kept only for backward compatibility with existing callers.
    db_slave: when truthy, also sets db_config.db_slave_path to
       MULTI_DB_SLAVE_PATH.

    Raises Exception when the file is missing or parses to an empty document.
    """
    if not os.path.isfile(file_path):
        raise Exception('File: %s not found' % file_path)
    # `with` closes the file itself -- the original also called f.close()
    # redundantly inside both with-blocks.
    with open(file_path) as f:
        config_dict = load(f)
    if not config_dict:
        raise Exception('Load file:%s error' % file_path)
    # Substring match keeps callers' "server.xxx"-style keys working.
    if k.find("use_blas_threshold") != -1:
        config_dict['engine_config']['use_blas_threshold'] = int(v)
    elif k.find("cpu_cache_capacity") != -1:
        config_dict['cache_config']['cpu_cache_capacity'] = int(v)
    elif k.find("gpu_cache_capacity") != -1:
        config_dict['cache_config']['gpu_cache_capacity'] = int(v)
    elif k.find("resource_pool") != -1:
        config_dict['resource_config']['resource_pool'] = v

    if db_slave:
        config_dict['db_config']['db_slave_path'] = MULTI_DB_SLAVE_PATH
    with open(file_path, 'w') as f:
        dump(config_dict, f, default_flow_style=False)
|
||||
|
||||
|
||||
def pull_image(image):
    """Pull a docker image given as 'repository[:tag]'.

    The original `image.split(":")[1]` raised IndexError for images given
    without an explicit tag; we now default to "latest", matching docker's
    own behavior. Returns the docker API pull output.
    """
    repository, _, image_tag = image.partition(":")
    if not image_tag:
        image_tag = "latest"
    client = docker.APIClient(base_url='unix://var/run/docker.sock')
    logger.info("Start pulling image: %s" % image)
    return client.pull(repository, image_tag)
|
||||
|
||||
|
||||
def run_server(image, mem_limit=None, timeout=30, test_type="local", volume_name=None, db_slave=None):
    """Start a milvus server container from *image* and return the Container.

    test_type "local" mounts ./conf, ./logs and ./db from the CWD;
    "remote" mounts <volume_name>/logs and <volume_name>/db and requires
    volume_name. db_slave (int, remote mode only) adds slave data volumes
    data2..dataN. mem_limit and timeout are currently unused -- the
    resource-limit and wait logic is commented out below. Container logs
    are streamed to `logger` from a daemon thread.
    """
    # Imported lazily so importing this module does not require `colors`.
    import colors

    client = docker.from_env()
    # if mem_limit is None:
    #     mem_limit = psutil.virtual_memory().available
    # logger.info('Memory limit:', mem_limit)
    # cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1)
    # logger.info('Running on CPUs:', cpu_limit)
    # Ensure the host-side mount points exist; "already exists" is ignored.
    for dir_item in ['logs', 'db']:
        try:
            os.mkdir(os.path.abspath(dir_item))
        except Exception as e:
            pass

    if test_type == "local":
        volumes = {
            os.path.abspath('conf'):
                {'bind': '/opt/milvus/conf', 'mode': 'ro'},
            os.path.abspath('logs'):
                {'bind': '/opt/milvus/logs', 'mode': 'rw'},
            os.path.abspath('db'):
                {'bind': '/opt/milvus/db', 'mode': 'rw'},
        }
    elif test_type == "remote":
        if volume_name is None:
            raise Exception("No volume name")
        remote_log_dir = volume_name+'/logs'
        remote_db_dir = volume_name+'/db'

        for dir_item in [remote_log_dir, remote_db_dir]:
            if not os.path.isdir(dir_item):
                os.makedirs(dir_item, exist_ok=True)
        volumes = {
            os.path.abspath('conf'):
                {'bind': '/opt/milvus/conf', 'mode': 'ro'},
            remote_log_dir:
                {'bind': '/opt/milvus/logs', 'mode': 'rw'},
            remote_db_dir:
                {'bind': '/opt/milvus/db', 'mode': 'rw'}
        }
        # add volumes for the extra db slaves (data2..dataN)
        if db_slave and isinstance(db_slave, int):
            for i in range(2, db_slave+1):
                remote_db_dir = volume_name+'/data'+str(i)
                if not os.path.isdir(remote_db_dir):
                    os.makedirs(remote_db_dir, exist_ok=True)
                volumes[remote_db_dir] = {'bind': '/opt/milvus/data'+str(i), 'mode': 'rw'}

    # NOTE(review): if test_type is neither "local" nor "remote", `volumes`
    # is never bound and the run() call below raises NameError.
    container = client.containers.run(
        image,
        volumes=volumes,
        runtime="nvidia",  # requires the nvidia container runtime on the host
        ports={'19530/tcp': 19530, '8080/tcp': 8080},
        environment=["OMP_NUM_THREADS=48"],
        # cpuset_cpus=cpu_limit,
        # mem_limit=mem_limit,
        # environment=[""],
        detach=True)

    def stream_logs():
        # Forward container stdout/stderr to our logger, colored blue.
        for line in container.logs(stream=True):
            logger.info(colors.color(line.decode().rstrip(), fg='blue'))

    # Daemon thread so log streaming never blocks interpreter exit;
    # the `daemon=` constructor kwarg only exists on Python 3.
    if sys.version_info >= (3, 0):
        t = threading.Thread(target=stream_logs, daemon=True)
    else:
        t = threading.Thread(target=stream_logs)
        t.daemon = True
    t.start()

    logger.info('Container: %s started' % container)
    return container
    # exit_code = container.wait(timeout=timeout)
    # # Exit if exit code
    # if exit_code == 0:
    #     return container
    # elif exit_code is not None:
    #     print(colors.color(container.logs().decode(), fg='red'))
    #     raise Exception('Child process raised exception %s' % str(exit_code))
|
||||
|
||||
def restart_server(container):
    """Restart *container* via the low-level docker API and return it."""
    api = docker.APIClient(base_url='unix://var/run/docker.sock')

    api.restart(container.name)
    logger.info('Container: %s restarted' % container.name)
    return container
|
||||
|
||||
|
||||
def remove_container(container):
    """Force-remove *container* (kills it first if still running)."""
    container.remove(force=True)
    logger.info('Container: %s removed' % container)
|
||||
|
||||
|
||||
def remove_all_containers(image):
    """Stop and force-remove every running container created from *image*.

    Errors from the docker daemon are logged (with the underlying cause --
    the original swallowed it and logged only a generic message) but not
    re-raised, preserving the best-effort cleanup semantics.
    """
    client = docker.from_env()
    try:
        for container in client.containers.list():
            if image in container.image.tags:
                container.stop(timeout=30)
                container.remove(force=True)
    except Exception as e:
        logger.error("Failed to remove containers for image %s: %s" % (image, e))
|
||||
|
||||
|
||||
def container_exists(image):
    '''
    Check if a container exists for the given image name
    @params: image name
    @return: the (last) matching running container, or False when none found
    '''
    client = docker.from_env()
    found = False
    for running in client.containers.list():
        if image in running.image.tags:
            # keep scanning; the last match wins, as in the original
            found = running
    return found
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # print(pull_image('branch-0.3.1-debug'))
    # BUG FIX: the original called stop_server(), which is not defined
    # anywhere in this module and raised NameError when the script was run
    # directly. Use remove_all_containers(<image>) to stop/remove a running
    # server container instead.
    pass
|
|
@ -0,0 +1,14 @@
|
|||
node_modules
|
||||
npm-debug.log
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
.git
|
||||
.gitignore
|
||||
.env
|
||||
*/bin
|
||||
*/obj
|
||||
README.md
|
||||
LICENSE
|
||||
.vscode
|
||||
__pycache__
|
|
@ -0,0 +1,13 @@
|
|||
.python-version
|
||||
.pytest_cache
|
||||
__pycache__
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
test_out/
|
||||
*.pyc
|
||||
|
||||
db/
|
||||
logs/
|
||||
|
||||
.coverage
|
|
@ -0,0 +1,14 @@
|
|||
FROM python:3.6.8-jessie

LABEL Name=megasearch_engine_test Version=0.0.1

WORKDIR /app
ADD . /app

# Install build deps, build the Python requirements, then purge the deps in
# the SAME layer so they don't bloat the image. The original
# `apt-get remove --purge -y` named no packages and so removed nothing;
# apt list cleanup was also missing.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libc-dev build-essential && \
    python3 -m pip install -r requirements.txt && \
    apt-get remove --purge -y libc-dev build-essential && \
    rm -rf /var/lib/apt/lists/*

ENTRYPOINT [ "/app/docker-entrypoint.sh" ]
CMD [ "start" ]
|
|
@ -0,0 +1,143 @@
|
|||
# Milvus test cases
|
||||
|
||||
## * Interfaces test
|
||||
|
||||
### 1. 连接测试
|
||||
|
||||
#### 1.1 连接
|
||||
|
||||
| cases | expected |
|
||||
| ---------------- | -------------------------------------------- |
|
||||
| 非法IP 123.0.0.2 | method: connect raise error in given timeout |
|
||||
| 正常 uri | attr: connected assert true |
|
||||
| 非法 uri | method: connect raise error in given timeout |
|
||||
| 最大连接数 | all connection attrs: connected assert true |
|
||||
| | |
|
||||
|
||||
#### 1.2 断开连接
|
||||
|
||||
| cases | expected |
|
||||
| ------------------------ | ------------------- |
|
||||
| 正常连接下,断开连接 | connect raise error |
|
||||
| 正常连接下,重复断开连接 | connect raise error |
|
||||
|
||||
### 2. Table operation
|
||||
|
||||
#### 2.1 表创建
|
||||
|
||||
##### 2.1.1 表名
|
||||
|
||||
| cases | expected |
|
||||
| ------------------------- | ----------- |
|
||||
| 基础功能,参数正常 | status pass |
|
||||
| 表名已存在 | status fail |
|
||||
| 表名:"中文" | status pass |
|
||||
| 表名带特殊字符: "-39fsd-" | status pass |
|
||||
| 表名带空格: "test1 2" | status pass |
|
||||
| invalid dim: 0 | raise error |
|
||||
| invalid dim: -1 | raise error |
|
||||
| invalid dim: 100000000 | raise error |
|
||||
| invalid dim: "string" | raise error |
|
||||
| index_type: 0 | status pass |
|
||||
| index_type: 1 | status pass |
|
||||
| index_type: 2 | status pass |
|
||||
| index_type: string | raise error |
|
||||
| | |
|
||||
|
||||
##### 2.1.2 维数支持
|
||||
|
||||
| cases | expected |
|
||||
| --------------------- | ----------- |
|
||||
| 维数: 0 | raise error |
|
||||
| 维数负数: -1 | raise error |
|
||||
| 维数最大值: 100000000 | raise error |
|
||||
| 维数字符串: "string" | raise error |
|
||||
| | |
|
||||
|
||||
##### 2.1.3 索引类型支持
|
||||
|
||||
| cases | expected |
|
||||
| ---------------- | ----------- |
|
||||
| 索引类型: 0 | status pass |
|
||||
| 索引类型: 1 | status pass |
|
||||
| 索引类型: 2 | status pass |
|
||||
| 索引类型: string | raise error |
|
||||
| | |
|
||||
|
||||
#### 2.2 表说明
|
||||
|
||||
| cases | expected |
|
||||
| ---------------------- | -------------------------------- |
|
||||
| 创建表后,执行describe | 返回结构体,元素与创建表参数一致 |
|
||||
| | |
|
||||
|
||||
#### 2.3 表删除
|
||||
|
||||
| cases | expected |
|
||||
| -------------- | ---------------------- |
|
||||
| 删除已存在表名 | has_table return False |
|
||||
| 删除不存在表名 | status fail |
|
||||
| | |
|
||||
|
||||
#### 2.4 表是否存在
|
||||
|
||||
| cases | expected |
|
||||
| ----------------------- | ------------ |
|
||||
| 存在表,调用has_table | assert true |
|
||||
| 不存在表,调用has_table | assert false |
|
||||
| | |
|
||||
|
||||
#### 2.5 查询表记录条数
|
||||
|
||||
| cases | expected |
|
||||
| -------------------- | ------------------------ |
|
||||
| 空表 | 0 |
|
||||
| 空表插入数据(单条) | 1 |
|
||||
| 空表插入数据(多条) | assert length of vectors |
|
||||
|
||||
#### 2.6 查询表数量
|
||||
|
||||
| cases | expected |
|
||||
| --------------------------------------------- | -------------------------------- |
|
||||
| 两张表,一张空表,一张有数据:调用show tables | assert length of table list == 2 |
|
||||
| | |
|
||||
|
||||
### 3. Add vectors
|
||||
|
||||
| interfaces | cases | expected |
|
||||
| ----------- | --------------------------------------------------------- | ------------------------------------ |
|
||||
| add_vectors | add basic | assert length of ids == nq |
|
||||
| | add vectors into table not existed | status fail |
|
||||
| | dim not match: single vector | status fail |
|
||||
| | dim not match: vector list | status fail |
|
||||
| | single vector element empty | status fail |
|
||||
| | vector list element empty | status fail |
|
||||
| | query immediately after adding | status pass |
|
||||
| | query immediately after sleep 6s | status pass && length of result == 1 |
|
||||
| | concurrent add with multi threads(share one connection) | status pass |
|
||||
| | concurrent add with multi threads(independent connection) | status pass |
|
||||
| | concurrent add with multi process(independent connection) | status pass |
|
||||
| | index_type: 2 | status pass |
|
||||
| | index_type: string | raise error |
|
||||
| | | |
|
||||
|
||||
### 4. Search vectors
|
||||
|
||||
| interfaces | cases | expected |
|
||||
| -------------- | ------------------------------------------------- | -------------------------------- |
|
||||
| search_vectors | search basic(query vector in vectors, top-k<nq) | assert length of result == nq |
|
||||
| | search vectors into table not existed | status fail |
|
||||
| | basic top-k | score of query vectors == 100.0 |
|
||||
| | invalid top-k: 0 | raise error |
|
||||
| | invalid top-k: -1 | raise error |
|
||||
| | invalid top-k: "string" | raise error |
|
||||
| | top-k > nq | assert length of result == nq |
|
||||
| | concurrent search | status pass |
|
||||
| | query_range(get_current_day(), get_current_day()) | assert length of result == nq |
|
||||
| | invalid query_range: "" | raise error |
|
||||
| | query_range(get_last_day(2), get_last_day(1)) | assert length of result == 0 |
|
||||
| | query_range(get_last_day(2), get_current_day()) | assert length of result == nq |
|
||||
| | query_range(get_last_day(2), get_next_day(2)) | assert length of result == nq |
|
||||
| | query_range(get_current_day(), get_next_day(2)) | assert length of result == nq |
|
||||
| | query_range(get_next_day(1), get_next_day(2)) | assert length of result == 0 |
|
||||
| | score: vector[i] = vector[i]+-0.01 | score > 99.9 |
|
|
@ -0,0 +1,14 @@
|
|||
# Requirements
|
||||
* python 3.6.8
|
||||
|
||||
# How to use this Test Project
|
||||
```shell
|
||||
pytest . -q -v
|
||||
```
|
||||
with allure test report
|
||||
```shell
|
||||
pytest --alluredir=test_out . -q -v
|
||||
allure serve test_out
|
||||
```
|
||||
# Contribution getting started
|
||||
* Follow PEP-8 for naming and black for formatting.
|
|
@ -0,0 +1,27 @@
|
|||
* GLOBAL:
|
||||
FORMAT = "%datetime | %level | %logger | %msg"
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log"
|
||||
ENABLED = true
|
||||
TO_FILE = true
|
||||
TO_STANDARD_OUTPUT = false
|
||||
SUBSECOND_PRECISION = 3
|
||||
PERFORMANCE_TRACKING = false
|
||||
MAX_LOG_FILE_SIZE = 209715200 ## Throw log files away after 200MB
|
||||
* DEBUG:
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log"
|
||||
ENABLED = true
|
||||
* WARNING:
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log"
|
||||
* TRACE:
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log"
|
||||
* VERBOSE:
|
||||
FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
|
||||
TO_FILE = false
|
||||
TO_STANDARD_OUTPUT = false
|
||||
## Error logs
|
||||
* ERROR:
|
||||
ENABLED = true
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log"
|
||||
* FATAL:
|
||||
ENABLED = true
|
||||
FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log"
|
|
@ -0,0 +1,32 @@
|
|||
server_config:
|
||||
address: 0.0.0.0
|
||||
port: 19530
|
||||
deploy_mode: single
|
||||
time_zone: UTC+8
|
||||
|
||||
db_config:
|
||||
primary_path: /opt/milvus
|
||||
secondary_path:
|
||||
backend_url: sqlite://:@:/
|
||||
insert_buffer_size: 4
|
||||
build_index_gpu: 0
|
||||
preload_table:
|
||||
|
||||
metric_config:
|
||||
enable_monitor: true
|
||||
collector: prometheus
|
||||
prometheus_config:
|
||||
port: 8080
|
||||
|
||||
cache_config:
|
||||
cpu_cache_capacity: 8
|
||||
cpu_cache_threshold: 0.85
|
||||
cache_insert_data: false
|
||||
|
||||
engine_config:
|
||||
use_blas_threshold: 20
|
||||
|
||||
resource_config:
|
||||
resource_pool:
|
||||
- cpu
|
||||
- gpu0
|
|
@ -0,0 +1,128 @@
|
|||
import socket
|
||||
import pdb
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from utils import gen_unique_str
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
|
||||
index_file_size = 10
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--ip", action="store", default="localhost")
|
||||
parser.addoption("--port", action="store", default=19530)
|
||||
|
||||
|
||||
def check_server_connection(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
connected = True
|
||||
if ip and (ip not in ['localhost', '127.0.0.1']):
|
||||
try:
|
||||
socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
|
||||
except Exception as e:
|
||||
print("Socket connnet failed: %s" % str(e))
|
||||
connected = False
|
||||
return connected
|
||||
|
||||
|
||||
def get_args(request):
|
||||
args = {
|
||||
"ip": request.config.getoption("--ip"),
|
||||
"port": request.config.getoption("--port")
|
||||
}
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def connect(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
milvus = Milvus()
|
||||
try:
|
||||
milvus.connect(host=ip, port=port)
|
||||
except:
|
||||
pytest.exit("Milvus server can not connected, exit pytest ...")
|
||||
|
||||
def fin():
|
||||
try:
|
||||
milvus.disconnect()
|
||||
except:
|
||||
pass
|
||||
|
||||
request.addfinalizer(fin)
|
||||
return milvus
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dis_connect(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
milvus = Milvus()
|
||||
milvus.connect(host=ip, port=port)
|
||||
milvus.disconnect()
|
||||
def fin():
|
||||
try:
|
||||
milvus.disconnect()
|
||||
except:
|
||||
pass
|
||||
|
||||
request.addfinalizer(fin)
|
||||
return milvus
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def args(request):
|
||||
ip = request.config.getoption("--ip")
|
||||
port = request.config.getoption("--port")
|
||||
args = {"ip": ip, "port": port}
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def table(request, connect):
|
||||
ori_table_name = getattr(request.module, "table_id", "test")
|
||||
table_name = gen_unique_str(ori_table_name)
|
||||
dim = getattr(request.module, "dim", "128")
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.L2}
|
||||
status = connect.create_table(param)
|
||||
# logging.getLogger().info(status)
|
||||
if not status.OK():
|
||||
pytest.exit("Table can not be created, exit pytest ...")
|
||||
|
||||
def teardown():
|
||||
status, table_names = connect.show_tables()
|
||||
for table_name in table_names:
|
||||
connect.delete_table(table_name)
|
||||
|
||||
request.addfinalizer(teardown)
|
||||
|
||||
return table_name
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ip_table(request, connect):
|
||||
ori_table_name = getattr(request.module, "table_id", "test")
|
||||
table_name = gen_unique_str(ori_table_name)
|
||||
dim = getattr(request.module, "dim", "128")
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.IP}
|
||||
status = connect.create_table(param)
|
||||
# logging.getLogger().info(status)
|
||||
if not status.OK():
|
||||
pytest.exit("Table can not be created, exit pytest ...")
|
||||
|
||||
def teardown():
|
||||
status, table_names = connect.show_tables()
|
||||
for table_name in table_names:
|
||||
connect.delete_table(table_name)
|
||||
|
||||
request.addfinalizer(teardown)
|
||||
|
||||
return table_name
|
|
@ -0,0 +1,9 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
if [ "$1" = 'start' ]; then
|
||||
tail -f /dev/null
|
||||
fi
|
||||
|
||||
exec "$@"
|
|
@ -0,0 +1,9 @@
|
|||
[pytest]
|
||||
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
|
||||
|
||||
log_cli = true
|
||||
log_level = 20
|
||||
|
||||
timeout = 300
|
||||
|
||||
level = 1
|
|
@ -0,0 +1,25 @@
|
|||
astroid==2.2.5
|
||||
atomicwrites==1.3.0
|
||||
attrs==19.1.0
|
||||
importlib-metadata==0.15
|
||||
isort==4.3.20
|
||||
lazy-object-proxy==1.4.1
|
||||
mccabe==0.6.1
|
||||
more-itertools==7.0.0
|
||||
numpy==1.16.3
|
||||
pluggy==0.12.0
|
||||
py==1.8.0
|
||||
pylint==2.3.1
|
||||
pytest==4.5.0
|
||||
pytest-timeout==1.3.3
|
||||
pytest-repeat==0.8.0
|
||||
allure-pytest==2.7.0
|
||||
pytest-print==0.1.2
|
||||
pytest-level==0.1.1
|
||||
six==1.12.0
|
||||
thrift==0.11.0
|
||||
typed-ast==1.3.5
|
||||
wcwidth==0.1.7
|
||||
wrapt==1.11.1
|
||||
zipp==0.5.1
|
||||
pymilvus-test>=0.2.0
|
|
@ -0,0 +1,25 @@
|
|||
astroid==2.2.5
|
||||
atomicwrites==1.3.0
|
||||
attrs==19.1.0
|
||||
importlib-metadata==0.15
|
||||
isort==4.3.20
|
||||
lazy-object-proxy==1.4.1
|
||||
mccabe==0.6.1
|
||||
more-itertools==7.0.0
|
||||
numpy==1.16.3
|
||||
pluggy==0.12.0
|
||||
py==1.8.0
|
||||
pylint==2.3.1
|
||||
pytest==4.5.0
|
||||
pytest-timeout==1.3.3
|
||||
pytest-repeat==0.8.0
|
||||
allure-pytest==2.7.0
|
||||
pytest-print==0.1.2
|
||||
pytest-level==0.1.1
|
||||
six==1.12.0
|
||||
thrift==0.11.0
|
||||
typed-ast==1.3.5
|
||||
wcwidth==0.1.7
|
||||
wrapt==1.11.1
|
||||
zipp==0.5.1
|
||||
pymilvus>=0.1.24
|
|
@ -0,0 +1,24 @@
|
|||
astroid==2.2.5
|
||||
atomicwrites==1.3.0
|
||||
attrs==19.1.0
|
||||
importlib-metadata==0.15
|
||||
isort==4.3.20
|
||||
lazy-object-proxy==1.4.1
|
||||
mccabe==0.6.1
|
||||
more-itertools==7.0.0
|
||||
numpy==1.16.3
|
||||
pluggy==0.12.0
|
||||
py==1.8.0
|
||||
pylint==2.3.1
|
||||
pytest==4.5.0
|
||||
pytest-timeout==1.3.3
|
||||
pytest-repeat==0.8.0
|
||||
allure-pytest==2.7.0
|
||||
pytest-print==0.1.2
|
||||
pytest-level==0.1.1
|
||||
six==1.12.0
|
||||
thrift==0.11.0
|
||||
typed-ast==1.3.5
|
||||
wcwidth==0.1.7
|
||||
wrapt==1.11.1
|
||||
zipp==0.5.1
|
|
@ -0,0 +1,4 @@
|
|||
#/bin/bash
|
||||
|
||||
|
||||
pytest . $@
|
|
@ -0,0 +1,41 @@
|
|||
'''
|
||||
Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
Proprietary and confidential.
|
||||
'''
|
||||
|
||||
'''
|
||||
Test Description:
|
||||
|
||||
This document is only a template to show how to write a auto-test script
|
||||
|
||||
本文档仅仅是个展示如何编写自动化测试脚本的模板
|
||||
|
||||
'''
|
||||
|
||||
import pytest
|
||||
from milvus import Milvus
|
||||
|
||||
|
||||
class TestConnection:
|
||||
def test_connect_localhost(self):
|
||||
|
||||
"""
|
||||
TestCase1.1
|
||||
Test target: This case is to check if the server can be connected.
|
||||
Test method: Call API: milvus.connect to connect local milvus server, ip address: 127.0.0.1 and port: 19530, check the return status
|
||||
Expectation: Return status is OK.
|
||||
|
||||
测试目的:本用例测试客户端是否可以与服务器建立连接
|
||||
测试方法:调用SDK API: milvus.connect方法连接本地服务器,IP地址:127.0.0.1,端口:19530,检查调用返回状态
|
||||
期望结果:返回状态是:OK
|
||||
|
||||
"""
|
||||
|
||||
milvus = Milvus()
|
||||
milvus.connect(host='127.0.0.1', port='19530')
|
||||
assert milvus.connected
|
||||
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,386 @@
|
|||
import pytest
|
||||
from milvus import Milvus
|
||||
import pdb
|
||||
import threading
|
||||
from multiprocessing import Process
|
||||
from utils import *
|
||||
|
||||
__version__ = '0.5.0'
|
||||
CONNECT_TIMEOUT = 12
|
||||
|
||||
|
||||
class TestConnect:
|
||||
|
||||
def local_ip(self, args):
|
||||
'''
|
||||
check if ip is localhost or not
|
||||
'''
|
||||
if not args["ip"] or args["ip"] == 'localhost' or args["ip"] == "127.0.0.1":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def test_disconnect(self, connect):
|
||||
'''
|
||||
target: test disconnect
|
||||
method: disconnect a connected client
|
||||
expected: connect failed after disconnected
|
||||
'''
|
||||
res = connect.disconnect()
|
||||
assert res.OK()
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.server_version()
|
||||
|
||||
def test_disconnect_repeatedly(self, connect, args):
|
||||
'''
|
||||
target: test disconnect repeatedly
|
||||
method: disconnect a connected client, disconnect again
|
||||
expected: raise an error after disconnected
|
||||
'''
|
||||
if not connect.connected():
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus.connect(uri=uri_value)
|
||||
res = milvus.disconnect()
|
||||
with pytest.raises(Exception) as e:
|
||||
res = milvus.disconnect()
|
||||
else:
|
||||
res = connect.disconnect()
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.disconnect()
|
||||
|
||||
def test_connect_correct_ip_port(self, args):
|
||||
'''
|
||||
target: test connect with corrent ip and port value
|
||||
method: set correct ip and port
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = Milvus()
|
||||
milvus.connect(host=args["ip"], port=args["port"])
|
||||
assert milvus.connected()
|
||||
|
||||
def test_connect_connected(self, args):
|
||||
'''
|
||||
target: test connect and disconnect with corrent ip and port value, assert connected value
|
||||
method: set correct ip and port
|
||||
expected: connected is False
|
||||
'''
|
||||
milvus = Milvus()
|
||||
milvus.connect(host=args["ip"], port=args["port"])
|
||||
milvus.disconnect()
|
||||
assert not milvus.connected()
|
||||
|
||||
# TODO: Currently we test with remote IP, localhost testing need to add
|
||||
def _test_connect_ip_localhost(self, args):
|
||||
'''
|
||||
target: test connect with ip value: localhost
|
||||
method: set host localhost
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = Milvus()
|
||||
milvus.connect(host='localhost', port=args["port"])
|
||||
assert milvus.connected()
|
||||
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
def test_connect_wrong_ip_null(self, args):
|
||||
'''
|
||||
target: test connect with wrong ip value
|
||||
method: set host null
|
||||
expected: not use default ip, connected is False
|
||||
'''
|
||||
milvus = Milvus()
|
||||
ip = ""
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(host=ip, port=args["port"], timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
def test_connect_uri(self, args):
|
||||
'''
|
||||
target: test connect with correct uri
|
||||
method: uri format and value are both correct
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus.connect(uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
def test_connect_uri_null(self, args):
|
||||
'''
|
||||
target: test connect with null uri
|
||||
method: uri set null
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = ""
|
||||
|
||||
if self.local_ip(args):
|
||||
milvus.connect(uri=uri_value, timeout=1)
|
||||
assert milvus.connected()
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(uri=uri_value, timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
def test_connect_wrong_uri_wrong_port_null(self, args):
|
||||
'''
|
||||
target: test uri connect with port value wouldn't connected
|
||||
method: set uri port null
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:" % args["ip"]
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(uri=uri_value, timeout=1)
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
def test_connect_wrong_uri_wrong_ip_null(self, args):
|
||||
'''
|
||||
target: test uri connect with ip value wouldn't connected
|
||||
method: set uri ip null
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://:%s" % args["port"]
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(uri=uri_value, timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
# TODO: enable
|
||||
def _test_connect_with_multiprocess(self, args):
|
||||
'''
|
||||
target: test uri connect with multiprocess
|
||||
method: set correct uri, test with multiprocessing connecting
|
||||
expected: all connection is connected
|
||||
'''
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
process_num = 4
|
||||
processes = []
|
||||
|
||||
def connect(milvus):
|
||||
milvus.connect(uri=uri_value)
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
for i in range(process_num):
|
||||
milvus = Milvus()
|
||||
p = Process(target=connect, args=(milvus, ))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_connect_repeatedly(self, args):
|
||||
'''
|
||||
target: test connect repeatedly
|
||||
method: connect again
|
||||
expected: status.code is 0, and status.message shows have connected already
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus.connect(uri=uri_value)
|
||||
|
||||
milvus.connect(uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
def test_connect_disconnect_repeatedly_once(self, args):
|
||||
'''
|
||||
target: test connect and disconnect repeatedly
|
||||
method: disconnect, and then connect, assert connect status
|
||||
expected: status.code is 0
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus.connect(uri=uri_value)
|
||||
|
||||
milvus.disconnect()
|
||||
milvus.connect(uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
def test_connect_disconnect_repeatedly_times(self, args):
|
||||
'''
|
||||
target: test connect and disconnect for 10 times repeatedly
|
||||
method: disconnect, and then connect, assert connect status
|
||||
expected: status.code is 0
|
||||
'''
|
||||
times = 10
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus.connect(uri=uri_value)
|
||||
for i in range(times):
|
||||
milvus.disconnect()
|
||||
milvus.connect(uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
# TODO: enable
|
||||
def _test_connect_disconnect_with_multiprocess(self, args):
|
||||
'''
|
||||
target: test uri connect and disconnect repeatly with multiprocess
|
||||
method: set correct uri, test with multiprocessing connecting and disconnecting
|
||||
expected: all connection is connected after 10 times operation
|
||||
'''
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
process_num = 4
|
||||
processes = []
|
||||
|
||||
def connect(milvus):
|
||||
milvus.connect(uri=uri_value)
|
||||
milvus.disconnect()
|
||||
milvus.connect(uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
for i in range(process_num):
|
||||
milvus = Milvus()
|
||||
p = Process(target=connect, args=(milvus, ))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_connect_param_priority_no_port(self, args):
|
||||
'''
|
||||
target: both host_ip_port / uri are both given, if port is null, use the uri params
|
||||
method: port set "", check if wrong uri connection is ok
|
||||
expected: connect raise an exception and connected is false
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:19540" % args["ip"]
|
||||
milvus.connect(host=args["ip"], port="", uri=uri_value)
|
||||
assert milvus.connected()
|
||||
|
||||
def test_connect_param_priority_uri(self, args):
|
||||
'''
|
||||
target: both host_ip_port / uri are both given, if host is null, use the uri params
|
||||
method: host set "", check if correct uri connection is ok
|
||||
expected: connected is False
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
def test_connect_param_priority_both_hostip_uri(self, args):
|
||||
'''
|
||||
target: both host_ip_port / uri are both given, and not null, use the uri params
|
||||
method: check if wrong uri connection is ok
|
||||
expected: connect raise an exception and connected is false
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(host=args["ip"], port=19540, uri=uri_value, timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
def _test_add_vector_and_disconnect_concurrently(self):
|
||||
'''
|
||||
Target: test disconnect in the middle of add vectors
|
||||
Method:
|
||||
a. use coroutine or multi-processing, to simulate network crashing
|
||||
b. data_set not too large incase disconnection happens when data is underd-preparing
|
||||
c. data_set not too small incase disconnection happens when data has already been transferred
|
||||
d. make sure disconnection happens when data is in-transport
|
||||
Expected: Failure, get_table_row_count == 0
|
||||
|
||||
'''
|
||||
pass
|
||||
|
||||
def _test_search_vector_and_disconnect_concurrently(self):
|
||||
'''
|
||||
Target: Test disconnect in the middle of search vectors(with large nq and topk)multiple times, and search/add vectors still work
|
||||
Method:
|
||||
a. coroutine or multi-processing, to simulate network crashing
|
||||
b. connect, search and disconnect, repeating many times
|
||||
c. connect and search, add vectors
|
||||
Expected: Successfully searched back, successfully added
|
||||
|
||||
'''
|
||||
pass
|
||||
|
||||
def _test_thread_safe_with_one_connection_shared_in_multi_threads(self):
|
||||
'''
|
||||
Target: test 1 connection thread safe
|
||||
Method: 1 connection shared in multi-threads, all adding vectors, or other things
|
||||
Expected: Functional as one thread
|
||||
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
class TestConnectIPInvalid(object):
|
||||
"""
|
||||
Test connect server with invalid ip
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ips()
|
||||
)
|
||||
def get_invalid_ip(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
def test_connect_with_invalid_ip(self, args, get_invalid_ip):
|
||||
milvus = Milvus()
|
||||
ip = get_invalid_ip
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(host=ip, port=args["port"], timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
|
||||
class TestConnectPortInvalid(object):
|
||||
"""
|
||||
Test connect server with invalid ip
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ports()
|
||||
)
|
||||
def get_invalid_port(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
def test_connect_with_invalid_port(self, args, get_invalid_port):
|
||||
'''
|
||||
target: test ip:port connect with invalid port value
|
||||
method: set port in gen_invalid_ports
|
||||
expected: connected is False
|
||||
'''
|
||||
milvus = Milvus()
|
||||
port = get_invalid_port
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(host=args["ip"], port=port, timeout=1)
|
||||
assert not milvus.connected()
|
||||
|
||||
|
||||
class TestConnectURIInvalid(object):
|
||||
"""
|
||||
Test connect server with invalid uri
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_uris()
|
||||
)
|
||||
def get_invalid_uri(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
def test_connect_with_invalid_uri(self, get_invalid_uri):
|
||||
'''
|
||||
target: test uri connect with invalid uri value
|
||||
method: set port in gen_invalid_uris
|
||||
expected: connected is False
|
||||
'''
|
||||
milvus = Milvus()
|
||||
uri_value = get_invalid_uri
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus.connect(uri=uri_value, timeout=1)
|
||||
assert not milvus.connected()
|
|
@ -0,0 +1,419 @@
|
|||
import time
|
||||
import random
|
||||
import pdb
|
||||
import logging
|
||||
import threading
|
||||
from builtins import Exception
|
||||
from multiprocessing import Pool, Process
|
||||
import pytest
|
||||
|
||||
from milvus import Milvus, IndexType
|
||||
from utils import *
|
||||
|
||||
|
||||
dim = 128
|
||||
index_file_size = 10
|
||||
table_id = "test_delete"
|
||||
DELETE_TIMEOUT = 60
|
||||
vectors = gen_vectors(100, dim)
|
||||
|
||||
class TestDeleteVectorsBase:
|
||||
"""
|
||||
generate invalid query range params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_current_day(), get_current_day()),
|
||||
(get_last_day(1), get_last_day(1)),
|
||||
(get_next_day(1), get_next_day(1))
|
||||
]
|
||||
)
|
||||
def get_invalid_range(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_invalid_range(self, connect, table, get_invalid_range):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with invalid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_invalid_range[0]
|
||||
end_date = get_invalid_range[1]
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.delete_vectors_by_range(table, start_date, end_date)
|
||||
assert not status.OK()
|
||||
|
||||
"""
|
||||
generate valid query range params, no search result
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_last_day(2), get_last_day(1)),
|
||||
(get_last_day(2), get_current_day()),
|
||||
(get_next_day(1), get_next_day(2))
|
||||
]
|
||||
)
|
||||
def get_valid_range_no_result(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range_no_result(self, connect, table, get_valid_range_no_result):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_valid_range_no_result[0]
|
||||
end_date = get_valid_range_no_result[1]
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
time.sleep(2)
|
||||
status = connect.delete_vectors_by_range(table, start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(table)
|
||||
assert result == 100
|
||||
|
||||
"""
|
||||
generate valid query range params, no search result
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_last_day(2), get_next_day(2)),
|
||||
(get_current_day(), get_next_day(2)),
|
||||
]
|
||||
)
|
||||
def get_valid_range(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range(self, connect, table, get_valid_range):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_valid_range[0]
|
||||
end_date = get_valid_range[1]
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
time.sleep(2)
|
||||
status = connect.delete_vectors_by_range(table, start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(table)
|
||||
assert result == 0
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_index_params()
|
||||
)
|
||||
def get_index_params(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range_index_created(self, connect, table, get_index_params):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date))
|
||||
status = connect.delete_vectors_by_range(table, start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(table)
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_no_data(self, connect, table):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params, and no data in db
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
# status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.delete_vectors_by_range(table, start_date, end_date)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_table_not_existed(self, connect):
|
||||
'''
|
||||
target: test delete vectors, table not existed in db
|
||||
method: call `delete_vectors_by_range`, with table not existed
|
||||
expected: return code not 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
table_name = gen_unique_str("not_existed_table")
|
||||
status = connect.delete_vectors_by_range(table_name, start_date, end_date)
|
||||
assert not status.OK()
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_table_None(self, connect, table):
|
||||
'''
|
||||
target: test delete vectors, table set Nope
|
||||
method: call `delete_vectors_by_range`, with table value is None
|
||||
expected: return code not 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
table_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status = connect.delete_vectors_by_range(table_name, start_date, end_date)
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range):
|
||||
'''
|
||||
target: test delete vectors is correct or not with multiple tables of L2
|
||||
method: create 50 tables and add vectors into them , then delete vectors
|
||||
in valid range
|
||||
expected: return code 0
|
||||
'''
|
||||
nq = 100
|
||||
vectors = gen_vectors(nq, dim)
|
||||
table_list = []
|
||||
for i in range(50):
|
||||
table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables')
|
||||
table_list.append(table_name)
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.L2}
|
||||
connect.create_table(param)
|
||||
status, ids = connect.add_vectors(table_name=table_name, records=vectors)
|
||||
time.sleep(2)
|
||||
start_date = get_valid_range[0]
|
||||
end_date = get_valid_range[1]
|
||||
for i in range(50):
|
||||
status = connect.delete_vectors_by_range(table_list[i], start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(table_list[i])
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestDeleteVectorsIP:
|
||||
"""
|
||||
generate invalid query range params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_current_day(), get_current_day()),
|
||||
(get_last_day(1), get_last_day(1)),
|
||||
(get_next_day(1), get_next_day(1))
|
||||
]
|
||||
)
|
||||
def get_invalid_range(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_invalid_range(self, connect, ip_table, get_invalid_range):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with invalid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_invalid_range[0]
|
||||
end_date = get_invalid_range[1]
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
|
||||
assert not status.OK()
|
||||
|
||||
"""
|
||||
generate valid query range params, no search result
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_last_day(2), get_last_day(1)),
|
||||
(get_last_day(2), get_current_day()),
|
||||
(get_next_day(1), get_next_day(2))
|
||||
]
|
||||
)
|
||||
def get_valid_range_no_result(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range_no_result(self, connect, ip_table, get_valid_range_no_result):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_valid_range_no_result[0]
|
||||
end_date = get_valid_range_no_result[1]
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
time.sleep(2)
|
||||
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(ip_table)
|
||||
assert result == 100
|
||||
|
||||
"""
|
||||
generate valid query range params, no search result
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_last_day(2), get_next_day(2)),
|
||||
(get_current_day(), get_next_day(2)),
|
||||
]
|
||||
)
|
||||
def get_valid_range(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range(self, connect, ip_table, get_valid_range):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_valid_range[0]
|
||||
end_date = get_valid_range[1]
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
time.sleep(2)
|
||||
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(ip_table)
|
||||
assert result == 0
|
||||
|
||||
    # Fixture: every index-parameter combination produced by gen_index_params().
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range_index_created(self, connect, ip_table, get_index_params):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
status = connect.create_index(ip_table, index_params)
|
||||
logging.getLogger().info(status)
|
||||
logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date))
|
||||
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(ip_table)
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_no_data(self, connect, ip_table):
|
||||
'''
|
||||
target: test delete vectors, no index created
|
||||
method: call `delete_vectors_by_range`, with valid date params, and no data in db
|
||||
expected: return code 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
# status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_table_None(self, connect, ip_table):
|
||||
'''
|
||||
target: test delete vectors, table set Nope
|
||||
method: call `delete_vectors_by_range`, with table value is None
|
||||
expected: return code not 0
|
||||
'''
|
||||
start_date = get_current_day()
|
||||
end_date = get_next_day(2)
|
||||
table_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status = connect.delete_vectors_by_range(table_name, start_date, end_date)
|
||||
|
||||
@pytest.mark.timeout(DELETE_TIMEOUT)
|
||||
def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range):
|
||||
'''
|
||||
target: test delete vectors is correct or not with multiple tables of IP
|
||||
method: create 50 tables and add vectors into them , then delete vectors
|
||||
in valid range
|
||||
expected: return code 0
|
||||
'''
|
||||
nq = 100
|
||||
vectors = gen_vectors(nq, dim)
|
||||
table_list = []
|
||||
for i in range(50):
|
||||
table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables')
|
||||
table_list.append(table_name)
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.IP}
|
||||
connect.create_table(param)
|
||||
status, ids = connect.add_vectors(table_name=table_name, records=vectors)
|
||||
time.sleep(2)
|
||||
start_date = get_valid_range[0]
|
||||
end_date = get_valid_range[1]
|
||||
for i in range(50):
|
||||
status = connect.delete_vectors_by_range(table_list[i], start_date, end_date)
|
||||
assert status.OK()
|
||||
status, result = connect.get_table_row_count(table_list[i])
|
||||
assert result == 0
|
||||
|
||||
class TestDeleteVectorsParamsInvalid:
    """Negative-parameter tests for `delete_vectors_by_range`."""

    """
    Test search table with invalid table names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_table_names()
    )
    def get_table_name(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_delete_vectors_table_invalid_name(self, connect, get_table_name):
        '''
        target: test delete vectors with an invalid table name
        method: call `delete_vectors_by_range` with an invalid table name
        expected: return code not 0
        '''
        # Fixed: filled in the previously empty docstring and removed the
        # unused `top_k` / `nprobe` locals (copy-paste residue from a search
        # test — they were never passed to any call).
        start_date = get_current_day()
        end_date = get_next_day(2)
        table_name = get_table_name
        logging.getLogger().info(table_name)
        status = connect.delete_vectors_by_range(table_name, start_date, end_date)
        assert not status.OK()

    """
    Test search table with invalid query ranges
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_query_ranges()
    )
    def get_query_ranges(self, request):
        yield request.param

    @pytest.mark.timeout(DELETE_TIMEOUT)
    def test_delete_vectors_range_invalid(self, connect, table, get_query_ranges):
        '''
        target: test delete function with a malformed query_range
        method: delete with an invalid query_range
        expected: raise an error, and the connection is normal
        '''
        start_date = get_query_ranges[0][0]
        end_date = get_query_ranges[0][1]
        status, ids = connect.add_vectors(table, vectors)
        logging.getLogger().info(get_query_ranges)
        with pytest.raises(Exception) as e:
            status = connect.delete_vectors_by_range(table, start_date, end_date)
|
|
@ -0,0 +1,966 @@
|
|||
"""
|
||||
For testing index operations, including `create_index`, `describe_index` and `drop_index` interfaces
|
||||
"""
|
||||
import logging
|
||||
import pytest
|
||||
import time
|
||||
import pdb
|
||||
import threading
|
||||
from multiprocessing import Pool, Process
|
||||
import numpy
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
from utils import *
|
||||
|
||||
# Scale and timeout constants shared by all index tests in this module.
nb = 100000  # number of vectors inserted per table
dim = 128  # vector dimensionality
index_file_size = 10  # table index_file_size creation parameter
vectors = gen_vectors(nb, dim)
# NOTE(review): numpy.linalg.norm on a 2-D array returns the Frobenius norm of
# the WHOLE matrix, so every row is scaled by one global factor rather than
# being individually normalized to unit length — confirm this is intended.
vectors /= numpy.linalg.norm(vectors)
vectors = vectors.tolist()  # presumably the SDK expects plain lists — verify
BUILD_TIMEOUT = 60  # seconds allowed per index-building test
nprobe = 1  # nprobe used by the search sanity checks below
|
||||
|
||||
|
||||
class TestIndexBase:
    # Fixture: every index-parameter combination from gen_index_params().
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    # Fixture: presumably a smaller/cheaper parameter set — see
    # utils.gen_simple_index_params for the actual contents.
    @pytest.fixture(
        scope="function",
        params=gen_simple_index_params()
    )
    def get_simple_index_params(self, request):
        yield request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index(self, connect, table, get_index_params):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_without_connect(self, dis_connect, table):
|
||||
'''
|
||||
target: test create index without connection
|
||||
method: create table and add vectors in it, check if added successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
status = dis_connect.create_index(table, random.choice(gen_index_params()))
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_search_with_query_vectors(self, connect, table, get_index_params):
|
||||
'''
|
||||
target: test create index interface, search with more query vectors
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
logging.getLogger().info(connect.describe_index(table))
|
||||
query_vecs = [vectors[0], vectors[1], vectors[2]]
|
||||
top_k = 5
|
||||
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
|
||||
assert status.OK()
|
||||
assert len(result) == len(query_vecs)
|
||||
logging.getLogger().info(result)
|
||||
|
||||
    # TODO: enable
    # Disabled: the leading underscore keeps pytest from collecting it.
    @pytest.mark.timeout(BUILD_TIMEOUT)
    @pytest.mark.level(2)
    def _test_create_index_multiprocessing(self, connect, table, args):
        '''
        target: test create index interface with multiprocess
        method: create table and add vectors in it, create index
        expected: return code equals to 0, and search success

        Eight separate client processes all build an index on the same table
        concurrently; afterwards a self-search must return distance 0.0.
        '''
        status, ids = connect.add_vectors(table, vectors)

        # Worker: build the index through this process's own client.
        # NOTE(review): create_index is called without index params here —
        # presumably the SDK applies a default; confirm.
        def build(connect):
            status = connect.create_index(table)
            assert status.OK()

        process_num = 8
        processes = []
        uri = "tcp://%s:%s" % (args["ip"], args["port"])

        for i in range(process_num):
            m = Milvus()
            m.connect(uri=uri)
            p = Process(target=build, args=(m,))
            processes.append(p)
            p.start()
            time.sleep(0.2)  # stagger process start-up
        for p in processes:
            p.join()

        # A vector searched against itself should be its own top hit.
        query_vec = [vectors[0]]
        top_k = 1
        status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
        assert len(result) == 1
        assert len(result[0]) == top_k
        assert result[0][0].distance == 0.0
|
||||
|
||||
# TODO: enable
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def _test_create_index_multiprocessing_multitable(self, connect, args):
|
||||
'''
|
||||
target: test create index interface with multiprocess
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
process_num = 8
|
||||
loop_num = 8
|
||||
processes = []
|
||||
|
||||
table = []
|
||||
j = 0
|
||||
while j < (process_num*loop_num):
|
||||
table_name = gen_unique_str("test_create_index_multiprocessing")
|
||||
table.append(table_name)
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_type': IndexType.FLAT,
|
||||
'store_raw_vector': False}
|
||||
connect.create_table(param)
|
||||
j = j + 1
|
||||
|
||||
def create_index():
|
||||
i = 0
|
||||
while i < loop_num:
|
||||
# assert connect.has_table(table[ids*process_num+i])
|
||||
status, ids = connect.add_vectors(table[ids*process_num+i], vectors)
|
||||
|
||||
status = connect.create_index(table[ids*process_num+i])
|
||||
assert status.OK()
|
||||
query_vec = [vectors[0]]
|
||||
top_k = 1
|
||||
status, result = connect.search_vectors(table[ids*process_num+i], top_k, nprobe, query_vec)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == top_k
|
||||
assert result[0][0].distance == 0.0
|
||||
i = i + 1
|
||||
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
|
||||
for i in range(process_num):
|
||||
m = Milvus()
|
||||
m.connect(uri=uri)
|
||||
ids = i
|
||||
p = Process(target=create_index, args=(m,ids))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
time.sleep(0.2)
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_create_index_table_not_existed(self, connect):
|
||||
'''
|
||||
target: test create index interface when table name not existed
|
||||
method: create table and add vectors in it, create index with an random table_name
|
||||
, make sure the table name not in index
|
||||
expected: return code not equals to 0, create index failed
|
||||
'''
|
||||
table_name = gen_unique_str(self.__class__.__name__)
|
||||
status = connect.create_index(table_name, random.choice(gen_index_params()))
|
||||
assert not status.OK()
|
||||
|
||||
def test_create_index_table_None(self, connect):
|
||||
'''
|
||||
target: test create index interface when table name is None
|
||||
method: create table and add vectors in it, create index with an table_name: None
|
||||
expected: return code not equals to 0, create index failed
|
||||
'''
|
||||
table_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status = connect.create_index(table_name, random.choice(gen_index_params()))
|
||||
|
||||
def test_create_index_no_vectors(self, connect, table):
|
||||
'''
|
||||
target: test create index interface when there is no vectors in table
|
||||
method: create table and add no vectors in it, and then create index
|
||||
expected: return code equals to 0
|
||||
'''
|
||||
status = connect.create_index(table, random.choice(gen_index_params()))
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_no_vectors_then_add_vectors(self, connect, table):
|
||||
'''
|
||||
target: test create index interface when there is no vectors in table, and does not affect the subsequent process
|
||||
method: create table and add no vectors in it, and then create index, add vectors in it
|
||||
expected: return code equals to 0
|
||||
'''
|
||||
status = connect.create_index(table, random.choice(gen_index_params()))
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_same_index_repeatedly(self, connect, table):
|
||||
'''
|
||||
target: check if index can be created repeatedly, with the same create_index params
|
||||
method: create index after index have been built
|
||||
expected: return code success, and search ok
|
||||
'''
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
index_params = random.choice(gen_index_params())
|
||||
# index_params = get_index_params
|
||||
status = connect.create_index(table, index_params)
|
||||
status = connect.create_index(table, index_params)
|
||||
assert status.OK()
|
||||
query_vec = [vectors[0]]
|
||||
top_k = 1
|
||||
status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == top_k
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_different_index_repeatedly(self, connect, table):
|
||||
'''
|
||||
target: check if index can be created repeatedly, with the different create_index params
|
||||
method: create another index with different index_params after index have been built
|
||||
expected: return code 0, and describe index result equals with the second index params
|
||||
'''
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
index_params = random.sample(gen_index_params(), 2)
|
||||
logging.getLogger().info(index_params)
|
||||
status = connect.create_index(table, index_params[0])
|
||||
status = connect.create_index(table, index_params[1])
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
assert result._nlist == index_params[1]["nlist"]
|
||||
assert result._table_name == table
|
||||
assert result._index_type == index_params[1]["index_type"]
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `describe_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_describe_index(self, connect, table, get_index_params):
|
||||
'''
|
||||
target: test describe index interface
|
||||
method: create table and add vectors in it, create index, call describe index
|
||||
expected: return code 0, and index instructure
|
||||
'''
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
assert result._nlist == index_params["nlist"]
|
||||
assert result._table_name == table
|
||||
assert result._index_type == index_params["index_type"]
|
||||
|
||||
    def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params):
        '''
        target: test create, describe and drop index interface with multiple tables of L2
        method: create tables and add vectors in it, create index, call describe index
        expected: return code 0, and index instructure
        '''
        # Local `vectors` deliberately shadows the 100k-row module global with
        # a small 100-row set to keep the 10-table loop fast.
        nq = 100
        vectors = gen_vectors(nq, dim)
        table_list = []
        for i in range(10):
            table_name = gen_unique_str('test_create_index_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.L2}
            connect.create_table(param)
            index_params = get_simple_index_params
            logging.getLogger().info(index_params)
            status, ids = connect.add_vectors(table_name=table_name, records=vectors)
            status = connect.create_index(table_name, index_params)
            assert status.OK()

        # Every table must report the index parameters it was built with.
        for i in range(10):
            status, result = connect.describe_index(table_list[i])
            logging.getLogger().info(result)
            assert result._nlist == index_params["nlist"]
            assert result._table_name == table_list[i]
            assert result._index_type == index_params["index_type"]

        # After drop_index, describe falls back to FLAT / nlist 16384 — per
        # these expectations this is the server default; confirm.
        for i in range(10):
            status = connect.drop_index(table_list[i])
            assert status.OK()
            status, result = connect.describe_index(table_list[i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[i]
            assert result._index_type == IndexType.FLAT
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_describe_index_without_connect(self, dis_connect, table):
|
||||
'''
|
||||
target: test describe index without connection
|
||||
method: describe index, and check if describe successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
status = dis_connect.describe_index(table)
|
||||
|
||||
def test_describe_index_table_not_existed(self, connect):
|
||||
'''
|
||||
target: test describe index interface when table name not existed
|
||||
method: create table and add vectors in it, create index with an random table_name
|
||||
, make sure the table name not in index
|
||||
expected: return code not equals to 0, describe index failed
|
||||
'''
|
||||
table_name = gen_unique_str(self.__class__.__name__)
|
||||
status, result = connect.describe_index(table_name)
|
||||
assert not status.OK()
|
||||
|
||||
def test_describe_index_table_None(self, connect):
|
||||
'''
|
||||
target: test describe index interface when table name is None
|
||||
method: create table and add vectors in it, create index with an table_name: None
|
||||
expected: return code not equals to 0, describe index failed
|
||||
'''
|
||||
table_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status = connect.describe_index(table_name)
|
||||
|
||||
def test_describe_index_not_create(self, connect, table):
|
||||
'''
|
||||
target: test describe index interface when index not created
|
||||
method: create table and add vectors in it, create index with an random table_name
|
||||
, make sure the table name not in index
|
||||
expected: return code not equals to 0, describe index failed
|
||||
'''
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
assert status.OK()
|
||||
# assert result._nlist == index_params["nlist"]
|
||||
# assert result._table_name == table
|
||||
# assert result._index_type == index_params["index_type"]
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_drop_index(self, connect, table, get_index_params):
|
||||
'''
|
||||
target: test drop index interface
|
||||
method: create table and add vectors in it, create index, call drop index
|
||||
expected: return code 0, and default index param
|
||||
'''
|
||||
index_params = get_index_params
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
status = connect.drop_index(table)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
assert result._nlist == 16384
|
||||
assert result._table_name == table
|
||||
assert result._index_type == IndexType.FLAT
|
||||
|
||||
def test_drop_index_repeatly(self, connect, table, get_simple_index_params):
|
||||
'''
|
||||
target: test drop index repeatly
|
||||
method: create index, call drop index, and drop again
|
||||
expected: return code 0
|
||||
'''
|
||||
index_params = get_simple_index_params
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
status = connect.drop_index(table)
|
||||
assert status.OK()
|
||||
status = connect.drop_index(table)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
assert result._nlist == 16384
|
||||
assert result._table_name == table
|
||||
assert result._index_type == IndexType.FLAT
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_drop_index_without_connect(self, dis_connect, table):
|
||||
'''
|
||||
target: test drop index without connection
|
||||
method: drop index, and check if drop successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
status = dis_connect.drop_index(table)
|
||||
|
||||
def test_drop_index_table_not_existed(self, connect):
|
||||
'''
|
||||
target: test drop index interface when table name not existed
|
||||
method: create table and add vectors in it, create index with an random table_name
|
||||
, make sure the table name not in index, and then drop it
|
||||
expected: return code not equals to 0, drop index failed
|
||||
'''
|
||||
table_name = gen_unique_str(self.__class__.__name__)
|
||||
status = connect.drop_index(table_name)
|
||||
assert not status.OK()
|
||||
|
||||
def test_drop_index_table_None(self, connect):
|
||||
'''
|
||||
target: test drop index interface when table name is None
|
||||
method: create table and add vectors in it, create index with an table_name: None
|
||||
expected: return code not equals to 0, drop index failed
|
||||
'''
|
||||
table_name = None
|
||||
with pytest.raises(Exception) as e:
|
||||
status = connect.drop_index(table_name)
|
||||
|
||||
def test_drop_index_table_not_create(self, connect, table):
|
||||
'''
|
||||
target: test drop index interface when index not created
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code not equals to 0, drop index failed
|
||||
'''
|
||||
index_params = random.choice(gen_index_params())
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
# no create index
|
||||
status = connect.drop_index(table)
|
||||
logging.getLogger().info(status)
|
||||
assert status.OK()
|
||||
|
||||
def test_create_drop_index_repeatly(self, connect, table, get_simple_index_params):
|
||||
'''
|
||||
target: test create / drop index repeatly, use the same index params
|
||||
method: create index, drop index, four times
|
||||
expected: return code 0
|
||||
'''
|
||||
index_params = get_simple_index_params
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
for i in range(2):
|
||||
status = connect.create_index(table, index_params)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
status = connect.drop_index(table)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
assert result._nlist == 16384
|
||||
assert result._table_name == table
|
||||
assert result._index_type == IndexType.FLAT
|
||||
|
||||
def test_create_drop_index_repeatly_different_index_params(self, connect, table):
|
||||
'''
|
||||
target: test create / drop index repeatly, use the different index params
|
||||
method: create index, drop index, four times, each tme use different index_params to create index
|
||||
expected: return code 0
|
||||
'''
|
||||
index_params = random.sample(gen_index_params(), 2)
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
for i in range(2):
|
||||
status = connect.create_index(table, index_params[i])
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
status = connect.drop_index(table)
|
||||
assert status.OK()
|
||||
status, result = connect.describe_index(table)
|
||||
logging.getLogger().info(result)
|
||||
assert result._nlist == 16384
|
||||
assert result._table_name == table
|
||||
assert result._index_type == IndexType.FLAT
|
||||
|
||||
|
||||
class TestIndexIP:
    # Fixture: every index-parameter combination from gen_index_params().
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    # Fixture: presumably a smaller/cheaper parameter set — see
    # utils.gen_simple_index_params for the actual contents.
    @pytest.fixture(
        scope="function",
        params=gen_simple_index_params()
    )
    def get_simple_index_params(self, request):
        yield request.param
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index(self, connect, ip_table, get_index_params):
|
||||
'''
|
||||
target: test create index interface
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
status = connect.create_index(ip_table, index_params)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_create_index_without_connect(self, dis_connect, ip_table):
|
||||
'''
|
||||
target: test create index without connection
|
||||
method: create table and add vectors in it, check if added successfully
|
||||
expected: raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
status = dis_connect.create_index(ip_table, random.choice(gen_index_params()))
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def test_create_index_search_with_query_vectors(self, connect, ip_table, get_index_params):
|
||||
'''
|
||||
target: test create index interface, search with more query vectors
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
index_params = get_index_params
|
||||
logging.getLogger().info(index_params)
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
status = connect.create_index(ip_table, index_params)
|
||||
logging.getLogger().info(connect.describe_index(ip_table))
|
||||
query_vecs = [vectors[0], vectors[1], vectors[2]]
|
||||
top_k = 5
|
||||
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
|
||||
assert status.OK()
|
||||
assert len(result) == len(query_vecs)
|
||||
# logging.getLogger().info(result)
|
||||
|
||||
    # TODO: enable
    # Disabled: the leading underscore keeps pytest from collecting it.
    @pytest.mark.timeout(BUILD_TIMEOUT)
    @pytest.mark.level(2)
    def _test_create_index_multiprocessing(self, connect, ip_table, args):
        '''
        target: test create index interface with multiprocess
        method: create table and add vectors in it, create index
        expected: return code equals to 0, and search success

        Eight separate client processes all build an index on the same IP
        table concurrently; afterwards a self-search is verified.
        '''
        status, ids = connect.add_vectors(ip_table, vectors)

        # Worker: build the index through this process's own client.
        # NOTE(review): create_index is called without index params here —
        # presumably the SDK applies a default; confirm.
        def build(connect):
            status = connect.create_index(ip_table)
            assert status.OK()

        process_num = 8
        processes = []
        uri = "tcp://%s:%s" % (args["ip"], args["port"])

        for i in range(process_num):
            m = Milvus()
            m.connect(uri=uri)
            p = Process(target=build, args=(m,))
            processes.append(p)
            p.start()
            time.sleep(0.2)  # stagger process start-up
        for p in processes:
            p.join()

        query_vec = [vectors[0]]
        top_k = 1
        status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
        assert len(result) == 1
        assert len(result[0]) == top_k
        # NOTE(review): for the IP metric a self-match usually yields the
        # maximal inner product rather than distance 0.0 — confirm this
        # expectation before enabling the test.
        assert result[0][0].distance == 0.0
|
||||
|
||||
# TODO: enable
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
|
||||
def _test_create_index_multiprocessing_multitable(self, connect, args):
|
||||
'''
|
||||
target: test create index interface with multiprocess
|
||||
method: create table and add vectors in it, create index
|
||||
expected: return code equals to 0, and search success
|
||||
'''
|
||||
process_num = 8
|
||||
loop_num = 8
|
||||
processes = []
|
||||
|
||||
table = []
|
||||
j = 0
|
||||
while j < (process_num*loop_num):
|
||||
table_name = gen_unique_str("test_create_index_multiprocessing")
|
||||
table.append(table_name)
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim}
|
||||
connect.create_table(param)
|
||||
j = j + 1
|
||||
|
||||
def create_index():
|
||||
i = 0
|
||||
while i < loop_num:
|
||||
# assert connect.has_table(table[ids*process_num+i])
|
||||
status, ids = connect.add_vectors(table[ids*process_num+i], vectors)
|
||||
|
||||
status = connect.create_index(table[ids*process_num+i])
|
||||
assert status.OK()
|
||||
query_vec = [vectors[0]]
|
||||
top_k = 1
|
||||
status, result = connect.search_vectors(table[ids*process_num+i], top_k, nprobe, query_vec)
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == top_k
|
||||
assert result[0][0].distance == 0.0
|
||||
i = i + 1
|
||||
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
|
||||
for i in range(process_num):
|
||||
m = Milvus()
|
||||
m.connect(uri=uri)
|
||||
ids = i
|
||||
p = Process(target=create_index, args=(m,ids))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
time.sleep(0.2)
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_create_index_no_vectors(self, connect, ip_table):
|
||||
'''
|
||||
target: test create index interface when there is no vectors in table
|
||||
method: create table and add no vectors in it, and then create index
|
||||
expected: return code equals to 0
|
||||
'''
|
||||
status = connect.create_index(ip_table, random.choice(gen_index_params()))
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_index_no_vectors_then_add_vectors(self, connect, ip_table):
    '''
    target: test create index interface when there is no vectors in table, and does not affect the subsequent process
    method: create table and add no vectors in it, and then create index, add vectors in it
    expected: return code equals to 0
    '''
    status = connect.create_index(ip_table, random.choice(gen_index_params()))
    # BUG fix: this status used to be discarded, although the docstring
    # expects the empty-table build to succeed as well.
    assert status.OK()
    status, ids = connect.add_vectors(ip_table, vectors)
    assert status.OK()
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_same_index_repeatedly(self, connect, ip_table):
    '''
    target: check if index can be created repeatedly, with the same create_index params
    method: create index after index have been built
    expected: return code success, and search ok
    '''
    status, ids = connect.add_vectors(ip_table, vectors)
    index_params = random.choice(gen_index_params())
    status = connect.create_index(ip_table, index_params)
    # BUG fix: the first build's status was silently overwritten before.
    assert status.OK()
    # Rebuilding with identical params must also succeed.
    status = connect.create_index(ip_table, index_params)
    assert status.OK()
    query_vec = [vectors[0]]
    top_k = 1
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
    assert len(result) == 1
    assert len(result[0]) == top_k
|
||||
|
||||
@pytest.mark.timeout(BUILD_TIMEOUT)
def test_create_different_index_repeatedly(self, connect, ip_table):
    '''
    target: check if index can be created repeatedly, with the different create_index params
    method: create another index with different index_params after index have been built
    expected: return code 0, and describe index result equals with the second index params
    '''
    status, ids = connect.add_vectors(ip_table, vectors)
    index_params = random.sample(gen_index_params(), 2)
    logging.getLogger().info(index_params)
    status = connect.create_index(ip_table, index_params[0])
    # BUG fix: the first build's status was silently overwritten before.
    assert status.OK()
    status = connect.create_index(ip_table, index_params[1])
    assert status.OK()
    # The second build wins: describe must reflect the latest params.
    status, result = connect.describe_index(ip_table)
    assert result._nlist == index_params[1]["nlist"]
    assert result._table_name == ip_table
    assert result._index_type == index_params[1]["index_type"]
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `describe_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_describe_index(self, connect, ip_table, get_index_params):
    '''
    target: test describe index interface
    method: create table and add vectors in it, create index, call describe index
    expected: return code 0, and index instructure
    '''
    logger = logging.getLogger()
    index_params = get_index_params
    logger.info(index_params)
    status, ids = connect.add_vectors(ip_table, vectors)
    status = connect.create_index(ip_table, index_params)
    status, result = connect.describe_index(ip_table)
    logger.info(result)
    # describe_index must echo back exactly what was built.
    assert result._table_name == ip_table
    assert result._nlist == index_params["nlist"]
    assert result._index_type == index_params["index_type"]
|
||||
|
||||
def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params):
    '''
    target: test create, describe and drop index interface with multiple tables of IP
    method: create tables and add vectors in it, create index, call describe index
    expected: return code 0, and index instructure
    '''
    num_tables = 10
    nq = 100
    vectors = gen_vectors(nq, dim)
    table_list = []
    # Build ten IP tables, each loaded with the same vectors and indexed.
    for _ in range(num_tables):
        table_name = gen_unique_str('test_create_index_multi_tables')
        table_list.append(table_name)
        connect.create_table({'table_name': table_name,
                              'dimension': dim,
                              'index_file_size': index_file_size,
                              'metric_type': MetricType.IP})
        index_params = get_simple_index_params
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(table_name=table_name, records=vectors)
        status = connect.create_index(table_name, index_params)
        assert status.OK()

    # Every table must report the index that was just built on it.
    for table_name in table_list:
        status, result = connect.describe_index(table_name)
        logging.getLogger().info(result)
        assert result._nlist == index_params["nlist"]
        assert result._table_name == table_name
        assert result._index_type == index_params["index_type"]

    # After dropping, describe falls back to the default FLAT index.
    for table_name in table_list:
        status = connect.drop_index(table_name)
        assert status.OK()
        status, result = connect.describe_index(table_name)
        logging.getLogger().info(result)
        assert result._nlist == 16384
        assert result._table_name == table_name
        assert result._index_type == IndexType.FLAT
|
||||
|
||||
@pytest.mark.level(2)
def test_describe_index_without_connect(self, dis_connect, ip_table):
    '''
    target: test describe index without connection
    method: describe index, and check if describe successfully
    expected: raise exception
    '''
    # Any call through a disconnected handle must raise.
    with pytest.raises(Exception) as e:
        dis_connect.describe_index(ip_table)
|
||||
|
||||
def test_describe_index_not_create(self, connect, ip_table):
    '''
    target: test describe index interface when index not created
    method: create table and add vectors in it, create index with an random table_name
    , make sure the table name not in index
    expected: return code not equals to 0, describe index failed
    '''
    add_status, vector_ids = connect.add_vectors(ip_table, vectors)
    # Describing a table with no explicit index still succeeds and reports
    # the default index settings.
    describe_status, index_info = connect.describe_index(ip_table)
    logging.getLogger().info(index_info)
    assert describe_status.OK()
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `drop_index` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_drop_index(self, connect, ip_table, get_index_params):
    '''
    target: test drop index interface
    method: create table and add vectors in it, create index, call drop index
    expected: return code 0, and default index param
    '''
    logger = logging.getLogger()
    status, ids = connect.add_vectors(ip_table, vectors)
    status = connect.create_index(ip_table, get_index_params)
    assert status.OK()
    status, result = connect.describe_index(ip_table)
    logger.info(result)
    status = connect.drop_index(ip_table)
    assert status.OK()
    # After the drop the table reports the default FLAT index again.
    status, result = connect.describe_index(ip_table)
    logger.info(result)
    assert result._index_type == IndexType.FLAT
    assert result._nlist == 16384
    assert result._table_name == ip_table
|
||||
|
||||
def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params):
    '''
    target: test drop index repeatly
    method: create index, call drop index, and drop again
    expected: return code 0
    '''
    logger = logging.getLogger()
    status, ids = connect.add_vectors(ip_table, vectors)
    status = connect.create_index(ip_table, get_simple_index_params)
    assert status.OK()
    status, result = connect.describe_index(ip_table)
    logger.info(result)
    # Dropping twice in a row must succeed both times (drop is idempotent).
    for _ in range(2):
        status = connect.drop_index(ip_table)
        assert status.OK()
    status, result = connect.describe_index(ip_table)
    logger.info(result)
    assert result._nlist == 16384
    assert result._table_name == ip_table
    assert result._index_type == IndexType.FLAT
|
||||
|
||||
@pytest.mark.level(2)
def test_drop_index_without_connect(self, dis_connect, ip_table):
    '''
    target: test drop index without connection
    method: drop index, and check if drop successfully
    expected: raise exception
    '''
    index_param = random.choice(gen_index_params())
    # Dropping through a disconnected handle must raise.
    with pytest.raises(Exception) as e:
        dis_connect.drop_index(ip_table, index_param)
|
||||
|
||||
def test_drop_index_table_not_create(self, connect, ip_table):
    '''
    target: test drop index interface when index not created
    method: create table and add vectors in it, create index
    expected: return code not equals to 0, drop index failed
    '''
    logger = logging.getLogger()
    index_param = random.choice(gen_index_params())
    logger.info(index_param)
    status, ids = connect.add_vectors(ip_table, vectors)
    status, result = connect.describe_index(ip_table)
    logger.info(result)
    # No index was ever created; the server still accepts the drop.
    status = connect.drop_index(ip_table)
    logger.info(status)
    assert status.OK()
|
||||
|
||||
def test_create_drop_index_repeatly(self, connect, ip_table, get_simple_index_params):
    '''
    target: test create / drop index repeatly, use the same index params
    method: create index, drop index, four times
    expected: return code 0
    '''
    logger = logging.getLogger()
    index_params = get_simple_index_params
    status, ids = connect.add_vectors(ip_table, vectors)
    # Build then tear down the same index; each cycle must succeed and
    # leave the table back on the FLAT default.
    for _ in range(2):
        status = connect.create_index(ip_table, index_params)
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        logger.info(result)
        status = connect.drop_index(ip_table)
        assert status.OK()
        status, result = connect.describe_index(ip_table)
        logger.info(result)
        assert result._nlist == 16384
        assert result._table_name == ip_table
        assert result._index_type == IndexType.FLAT
|
||||
|
||||
def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table):
    '''
    target: test create / drop index repeatly, use the different index params
    method: create index, drop index, four times, each tme use different index_params to create index
    expected: return code 0
    '''
    logger = logging.getLogger()
    param_pair = random.sample(gen_index_params(), 2)
    status, ids = connect.add_vectors(ip_table, vectors)
    for params in param_pair:
        status = connect.create_index(ip_table, params)
        assert status.OK()
        # The live index must match what was just built ...
        status, result = connect.describe_index(ip_table)
        assert result._nlist == params["nlist"]
        assert result._table_name == ip_table
        assert result._index_type == params["index_type"]
        status, result = connect.describe_index(ip_table)
        logger.info(result)
        status = connect.drop_index(ip_table)
        assert status.OK()
        # ... and revert to the FLAT default once dropped.
        status, result = connect.describe_index(ip_table)
        logger.info(result)
        assert result._nlist == 16384
        assert result._table_name == ip_table
        assert result._index_type == IndexType.FLAT
|
||||
|
||||
|
||||
class TestIndexTableInvalid(object):
    """
    Test create / describe / drop index interfaces with invalid table names
    """

    @pytest.fixture(scope="function", params=gen_invalid_table_names())
    def get_table_name(self, request):
        # Each test below runs once per invalid table name.
        yield request.param

    def test_create_index_with_invalid_tablename(self, connect, get_table_name):
        # Building an index on an invalid name must be rejected.
        status = connect.create_index(get_table_name, random.choice(gen_index_params()))
        assert not status.OK()

    def test_describe_index_with_invalid_tablename(self, connect, get_table_name):
        # Describing an index on an invalid name must fail.
        status, result = connect.describe_index(get_table_name)
        assert not status.OK()

    def test_drop_index_with_invalid_tablename(self, connect, get_table_name):
        # Dropping an index on an invalid name must fail.
        status = connect.drop_index(get_table_name)
        assert not status.OK()
|
||||
|
||||
|
||||
class TestCreateIndexParamsInvalid(object):
    """
    Test Building index with invalid table names, table names not in db
    """

    @pytest.fixture(scope="function", params=gen_invalid_index_params())
    def get_index_params(self, request):
        # Each test below runs once per invalid index-param combination.
        yield request.param

    @pytest.mark.level(2)
    def test_create_index_with_invalid_index_params(self, connect, table, get_index_params):
        index_params = get_index_params
        # Both keys are read up front so a malformed fixture value fails
        # loudly here rather than inside the client call.
        index_type = index_params["index_type"]
        nlist = index_params["nlist"]
        logging.getLogger().info(index_params)
        status, ids = connect.add_vectors(table, vectors)
        # Invalid index params are rejected client-side with an exception.
        with pytest.raises(Exception) as e:
            status = connect.create_index(table, index_params)
|
|
@ -0,0 +1,180 @@
|
|||
import pdb
|
||||
import copy
|
||||
import pytest
|
||||
import threading
|
||||
import datetime
|
||||
import logging
|
||||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
import numpy
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
from utils import *
|
||||
|
||||
# Shared fixtures for the mixed-workload tests below.
dim = 128                 # vector dimensionality used by every table
index_file_size = 10      # raw-data file size (MB) before a new segment starts
table_id = "test_mix"
add_interval_time = 2     # seconds to wait for inserts to become searchable
vectors = gen_vectors(100000, dim)
# NOTE(review): this divides by the norm of the whole matrix, not per row —
# confirm that is the intended normalization.
vectors /= numpy.linalg.norm(vectors)
vectors = vectors.tolist()
top_k = 1
nprobe = 1
epsilon = 0.0001          # tolerance when asserting a (near-)zero self-distance
index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
|
||||
|
||||
|
||||
class TestMixBase:
    """Mixed-workload tests: concurrent access and many heterogeneous tables."""

    # TODO: enable
    def _test_search_during_createIndex(self, args):
        # Runs a long search loop in one process while another process keeps
        # inserting vectors, to exercise concurrent access to one table.
        # NOTE(review): `create_index` below is defined but never spawned, and
        # `p_search` is started but never joined — confirm before enabling.
        loops = 100000
        table = "test_search_during_createIndex"
        query_vecs = [vectors[0], vectors[1]]
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        id_0 = 0; id_1 = 0
        milvus_instance = Milvus()
        milvus_instance.connect(uri=uri)
        milvus_instance.create_table({'table_name': table,
                                      'dimension': dim,
                                      'index_file_size': index_file_size,
                                      'metric_type': MetricType.L2})
        # Seed the table; remember the ids of the first two vectors so the
        # search loop can assert they remain their own nearest neighbours.
        for i in range(10):
            status, ids = milvus_instance.add_vectors(table, vectors)
            # logging.getLogger().info(ids)
            if i == 0:
                id_0 = ids[0]; id_1 = ids[1]

        def create_index(milvus_instance):
            # Build the module-level IVFLAT index and log the server's view.
            logging.getLogger().info("In create index")
            status = milvus_instance.create_index(table, index_params)
            logging.getLogger().info(status)
            status, result = milvus_instance.describe_index(table)
            logging.getLogger().info(result)

        def add_vectors(milvus_instance):
            # Insert one more batch of the module-level vectors.
            logging.getLogger().info("In add vectors")
            status, ids = milvus_instance.add_vectors(table, vectors)
            logging.getLogger().info(status)

        def search(milvus_instance):
            # Hammer the table with searches; the two seeded vectors must
            # always come back as their own top hit.
            for i in range(loops):
                status, result = milvus_instance.search_vectors(table, top_k, nprobe, query_vecs)
                logging.getLogger().info(status)
                assert result[0][0].id == id_0
                assert result[1][0].id == id_1

        # Each worker process gets its own fresh connection.
        milvus_instance = Milvus()
        milvus_instance.connect(uri=uri)
        p_search = Process(target=search, args=(milvus_instance, ))
        p_search.start()
        milvus_instance = Milvus()
        milvus_instance.connect(uri=uri)
        p_create = Process(target=add_vectors, args=(milvus_instance, ))
        p_create.start()
        p_create.join()

    def test_mix_multi_tables(self, connect):
        '''
        target: test functions with multiple tables of different metric_types and index_types
        method: create 60 tables which 30 are L2 and the other are IP, add vectors into them
        and test describe index and search
        expected: status ok
        '''
        nq = 10000
        vectors = gen_vectors(nq, dim)
        table_list = []
        # idx records the ids of vectors 0/10/20 for each table, three
        # entries per table, in table order.
        idx = []

        #create table and add vectors
        # Tables 0-29 use L2; tables 30-59 (next loop) use IP.
        for i in range(30):
            table_name = gen_unique_str('test_mix_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.L2}
            connect.create_table(param)
            status, ids = connect.add_vectors(table_name=table_name, records=vectors)
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
            assert status.OK()
        for i in range(30):
            table_name = gen_unique_str('test_mix_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.IP}
            connect.create_table(param)
            status, ids = connect.add_vectors(table_name=table_name, records=vectors)
            idx.append(ids[0])
            idx.append(ids[10])
            idx.append(ids[20])
            assert status.OK()
        # Give the server time to flush inserts before building indexes.
        time.sleep(2)

        #create index
        # Within each metric group of 30: first 10 FLAT, next 10 IVFLAT,
        # last 10 IVF_SQ8 — the describe loop below relies on this layout.
        for i in range(10):
            index_params = {'index_type': IndexType.FLAT, 'nlist': 16384}
            status = connect.create_index(table_list[i], index_params)
            assert status.OK()
            status = connect.create_index(table_list[30 + i], index_params)
            assert status.OK()
            index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
            status = connect.create_index(table_list[10 + i], index_params)
            assert status.OK()
            status = connect.create_index(table_list[40 + i], index_params)
            assert status.OK()
            index_params = {'index_type': IndexType.IVF_SQ8, 'nlist': 16384}
            status = connect.create_index(table_list[20 + i], index_params)
            assert status.OK()
            status = connect.create_index(table_list[50 + i], index_params)
            assert status.OK()

        #describe index
        # Verify each slot reports the index type assigned above.
        for i in range(10):
            status, result = connect.describe_index(table_list[i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[i]
            assert result._index_type == IndexType.FLAT
            status, result = connect.describe_index(table_list[10 + i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[10 + i]
            assert result._index_type == IndexType.IVFLAT
            status, result = connect.describe_index(table_list[20 + i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[20 + i]
            assert result._index_type == IndexType.IVF_SQ8
            status, result = connect.describe_index(table_list[30 + i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[30 + i]
            assert result._index_type == IndexType.FLAT
            status, result = connect.describe_index(table_list[40 + i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[40 + i]
            assert result._index_type == IndexType.IVFLAT
            status, result = connect.describe_index(table_list[50 + i])
            logging.getLogger().info(result)
            assert result._nlist == 16384
            assert result._table_name == table_list[50 + i]
            assert result._index_type == IndexType.IVF_SQ8

        #search
        # Each query vector must find the id it was inserted with; idx holds
        # three ids per table in the same order as query_vecs.
        query_vecs = [vectors[0], vectors[10], vectors[20]]
        for i in range(60):
            table = table_list[i]
            status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
            assert status.OK()
            assert len(result) == len(query_vecs)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
            for j in range(len(query_vecs)):
                assert check_result(result[j], idx[3 * i + j])
|
||||
|
||||
def check_result(result, id):
    """Return True when *id* appears among the first five hits of *result*."""
    # Slicing handles both branches of the original: with fewer than five
    # hits, result[:5] is simply the whole list.
    return any(hit.id == id for hit in result[:5])
|
|
@ -0,0 +1,77 @@
|
|||
import logging
|
||||
import pytest
|
||||
|
||||
__version__ = '0.5.0'
|
||||
|
||||
|
||||
class TestPing:
    def test_server_version(self, connect):
        '''
        target: test get the server version
        method: call the server_version method after connected
        expected: version should be the pymilvus version
        '''
        status, reported_version = connect.server_version()
        assert reported_version == __version__

    def test_server_status(self, connect):
        '''
        target: test get the server status
        method: call the server_status method after connected
        expected: status returned should be ok
        '''
        status, _msg = connect.server_status()
        assert status.OK()

    def _test_server_cmd_with_params_version(self, connect):
        '''
        target: test cmd: version
        method: cmd = "version" ...
        expected: when cmd = 'version', return version of server;
        '''
        logger = logging.getLogger()
        status, msg = connect.cmd("version")
        logger.info(status)
        logger.info(msg)
        assert status.OK()
        assert msg == __version__

    def _test_server_cmd_with_params_others(self, connect):
        '''
        target: test cmd: lalala
        method: cmd = "lalala" ...
        expected: when cmd = 'version', return version of server;
        '''
        logger = logging.getLogger()
        # An unrecognized command string must not be executed or crash.
        status, msg = connect.cmd("rm -rf test")
        logger.info(status)
        logger.info(msg)
        assert status.OK()

    def test_connected(self, connect):
        # The fixture hands us a live connection; sanity-check it.
        assert connect.connected()
|
||||
|
||||
|
||||
class TestPingDisconnect:
    def test_server_version(self, dis_connect):
        '''
        target: test get the server version, after disconnect
        method: call the server_version method after connected
        expected: version should not be the pymilvus version
        '''
        res = None
        with pytest.raises(Exception) as e:
            # BUG fix: the original called the undefined name `connect`, so
            # the expected exception was a NameError rather than an error
            # from the disconnected client handle.
            status, res = dis_connect.server_version()
        assert res is None

    def test_server_status(self, dis_connect):
        '''
        target: test get the server status, after disconnect
        method: call the server_status method after connected
        expected: status returned should be not ok
        '''
        status = None
        with pytest.raises(Exception) as e:
            # BUG fix: same undefined-name problem as above.
            status, msg = dis_connect.server_status()
        assert status is None
|
|
@ -0,0 +1,650 @@
|
|||
import pdb
|
||||
import copy
|
||||
import pytest
|
||||
import threading
|
||||
import datetime
|
||||
import logging
|
||||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
import numpy
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
from utils import *
|
||||
|
||||
# Shared fixtures for the search tests below.
dim = 128                 # vector dimensionality used by every table
table_id = "test_search"
add_interval_time = 2     # seconds to wait after inserts before searching
vectors = gen_vectors(100, dim)
# vectors /= numpy.linalg.norm(vectors)
# vectors = vectors.tolist()
# NOTE(review): misspelling of "nprobe" — kept as-is because the tests in
# this module reference it by this name.
nrpobe = 1
epsilon = 0.001           # distance tolerance for exact-match assertions
|
||||
|
||||
|
||||
class TestSearchBase:
|
||||
def init_data(self, connect, table, nb=100):
    '''
    Generate vectors and add it in table, before search vectors
    '''
    global vectors
    # Reuse the module-level 100-vector sample for the default size;
    # otherwise build a fresh batch of the requested size.
    add_vectors = vectors if nb == 100 else gen_vectors(nb, dim)
    status, ids = connect.add_vectors(table, add_vectors)
    # Give the server time to make the inserts searchable.
    sleep(add_interval_time)
    return add_vectors, ids
|
||||
|
||||
"""
|
||||
generate valid create_index params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_index_params()
|
||||
)
|
||||
def get_index_params(self, request):
|
||||
yield request.param
|
||||
|
||||
"""
|
||||
generate top-k params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[1, 99, 101, 1024, 2048, 2049]
|
||||
)
|
||||
def get_top_k(self, request):
|
||||
yield request.param
|
||||
|
||||
|
||||
def test_search_top_k_flat_index(self, connect, table, get_top_k):
    '''
    target: test basic search fuction, all the search params is corrent, change top-k value
    method: search with the given vectors, check the result
    expected: search status ok, and the length of the result is top_k
    '''
    vectors, ids = self.init_data(connect, table)
    query_vec = [vectors[0]]
    top_k = get_top_k
    nprobe = 1
    # BUG fix: the original passed the misspelled module-level constant
    # `nrpobe` and left this local `nprobe` unused (both happen to be 1,
    # so observable behavior is unchanged).
    status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
    if top_k <= 2048:
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        # The query vector must match itself at (near-)zero distance.
        assert result[0][0].distance <= epsilon
        assert check_result(result[0], ids[0])
    else:
        # top_k beyond the 2048 cap is rejected.
        assert not status.OK()
|
||||
|
||||
def test_search_l2_index_params(self, connect, table, get_index_params):
    '''
    target: test basic search fuction, all the search params is corrent, test all index params, and build
    method: search with the given vectors, check the result
    expected: search status ok, and the length of the result is top_k
    '''
    index_params = get_index_params
    logging.getLogger().info(index_params)
    vectors, ids = self.init_data(connect, table)
    status = connect.create_index(table, index_params)
    query_vec = [vectors[0]]
    top_k = 10
    nprobe = 1
    # BUG fix: the original passed the misspelled module constant `nrpobe`
    # and left this local `nprobe` unused (both are 1, so results match).
    status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
    logging.getLogger().info(result)
    if top_k <= 1024:
        # top_k is fixed at 10 here, so this branch always runs; the guard
        # mirrors the boundary check used by sibling tests.
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        assert result[0][0].distance <= epsilon
    else:
        assert not status.OK()
|
||||
|
||||
def test_search_ip_index_params(self, connect, ip_table, get_index_params):
    '''
    target: test basic search fuction, all the search params is corrent, test all index params, and build
    method: search with the given vectors, check the result
    expected: search status ok, and the length of the result is top_k
    '''
    index_params = get_index_params
    logging.getLogger().info(index_params)
    vectors, ids = self.init_data(connect, ip_table)
    status = connect.create_index(ip_table, index_params)
    query_vec = [vectors[0]]
    top_k = 10
    nprobe = 1
    # BUG fix: the original passed the misspelled module constant `nrpobe`
    # and left this local `nprobe` unused (both are 1, so results match).
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
    logging.getLogger().info(result)

    if top_k <= 1024:
        # top_k is fixed at 10 here, so this branch always runs.
        assert status.OK()
        assert len(result[0]) == min(len(vectors), top_k)
        assert check_result(result[0], ids[0])
        # For IP metric the self-match score is the query's inner product
        # with itself, within the generator's accuracy bound.
        assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance)
    else:
        assert not status.OK()
|
||||
|
||||
@pytest.mark.level(2)
def test_search_vectors_without_connect(self, dis_connect, table):
    '''
    target: test search vectors without connection
    method: use dis connected instance, call search method and check if search successfully
    expected: raise exception
    '''
    query_vectors = [vectors[0]]
    top_k, nprobe = 1, 1
    # Searching through a disconnected handle must raise.
    with pytest.raises(Exception) as e:
        dis_connect.search_vectors(table, top_k, nprobe, query_vectors)
|
||||
|
||||
def test_search_table_name_not_existed(self, connect, table):
    '''
    target: search table not existed
    method: search with the random table_name, which is not in db
    expected: status not ok
    '''
    missing_table = gen_unique_str("not_existed_table")
    query_vecs = [vectors[0]]
    # Searching a table that was never created must fail cleanly.
    status, result = connect.search_vectors(missing_table, 1, 1, query_vecs)
    assert not status.OK()
|
||||
|
||||
def test_search_table_name_None(self, connect, table):
    '''
    target: search table that table name is None
    method: search with the table_name: None
    expected: status not ok
    '''
    query_vecs = [vectors[0]]
    # A None table name is rejected client-side with an exception.
    with pytest.raises(Exception) as e:
        status, result = connect.search_vectors(None, 1, 1, query_vecs)
|
||||
|
||||
def test_search_top_k_query_records(self, connect, table):
    '''
    target: test search fuction, with search params: query_records
    method: search with the given query_records, which are subarrays of the inserted vectors
    expected: status ok and the returned vectors should be query_records
    '''
    top_k = 10
    nprobe = 1
    vectors, ids = self.init_data(connect, table)
    query_vecs = [vectors[0], vectors[55], vectors[99]]
    status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
    assert status.OK()
    assert len(result) == len(query_vecs)
    # Each query vector must find itself at (near-)zero distance.
    for hits in result:
        assert len(hits) == top_k
        assert hits[0].distance <= epsilon
|
||||
|
||||
"""
|
||||
generate invalid query range params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_current_day(), get_current_day()),
|
||||
(get_last_day(1), get_last_day(1)),
|
||||
(get_next_day(1), get_next_day(1))
|
||||
]
|
||||
)
|
||||
def get_invalid_range(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_search_invalid_query_ranges(self, connect, table, get_invalid_range):
    '''
    target: search table with query ranges
    method: search with the same query ranges
    expected: status not ok
    '''
    vectors, ids = self.init_data(connect, table)
    query_vecs = [vectors[0]]
    # A degenerate (start == end) range is rejected and yields no hits.
    status, result = connect.search_vectors(table, 2, 1, query_vecs,
                                            query_ranges=[get_invalid_range])
    assert not status.OK()
    assert len(result) == 0
|
||||
|
||||
"""
|
||||
generate valid query range params, no search result
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_last_day(2), get_last_day(1)),
|
||||
(get_last_day(2), get_current_day()),
|
||||
(get_next_day(1), get_next_day(2))
|
||||
]
|
||||
)
|
||||
def get_valid_range_no_result(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_search_valid_query_ranges_no_result(self, connect, table, get_valid_range_no_result):
    '''
    target: search table with normal query ranges, but no data in db
    method: search with query ranges (low, low)
    expected: length of result is 0
    '''
    vectors, ids = self.init_data(connect, table)
    query_vecs = [vectors[0]]
    # The range is well-formed but excludes today's inserts, so the call
    # succeeds while matching nothing.
    status, result = connect.search_vectors(table, 2, 1, query_vecs,
                                            query_ranges=[get_valid_range_no_result])
    assert status.OK()
    assert len(result) == 0
|
||||
|
||||
"""
|
||||
generate valid query range params, no search result
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(get_last_day(2), get_next_day(2)),
|
||||
(get_current_day(), get_next_day(2)),
|
||||
]
|
||||
)
|
||||
def get_valid_range(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_search_valid_query_ranges(self, connect, table, get_valid_range):
    '''
    target: search table with query ranges that cover the insertion time
    method: search with query ranges (low, normal)
    expected: status ok, one result set whose top hit is the query vector itself
    '''
    top_k = 2
    nprobe = 1
    vectors, ids = self.init_data(connect, table)
    query_vecs = [vectors[0]]
    query_ranges = [get_valid_range]
    status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
    assert status.OK()
    assert len(result) == 1
    # the query vector is in the table, so its distance to itself is ~0
    assert result[0][0].distance <= epsilon
|
||||
|
||||
def test_search_distance_l2_flat_index(self, connect, table):
    '''
    target: search table, and check the returned distance
    method: compare the returned distance with the Euclidean distance
            computed locally with numpy
    expected: the returned distance matches the computed value
    '''
    nb, top_k, nprobe = 2, 1, 1
    vectors, ids = self.init_data(connect, table, nb=nb)
    query_vecs = [[0.50] * dim]
    query = numpy.array(query_vecs[0])
    expected = min(
        numpy.linalg.norm(query - numpy.array(v)) for v in vectors[:2]
    )
    status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
    # the server reports squared L2, hence the sqrt before comparing
    assert abs(numpy.sqrt(result[0][0].distance) - expected) <= gen_inaccuracy(result[0][0].distance)
|
||||
|
||||
def test_search_distance_ip_flat_index(self, connect, ip_table):
    '''
    target: search ip_table, and check the returned distance
    method: compare the returned distance with the inner product
            computed locally with numpy
    expected: the returned distance matches the computed value
    '''
    nb, top_k, nprobe = 2, 1, 1
    vectors, ids = self.init_data(connect, ip_table, nb=nb)
    connect.create_index(ip_table, {
        "index_type": IndexType.FLAT,
        "nlist": 16384
    })
    logging.getLogger().info(connect.describe_index(ip_table))
    query_vecs = [[0.50] * dim]
    query = numpy.array(query_vecs[0])
    expected = max(
        numpy.inner(query, numpy.array(v)) for v in vectors[:2]
    )
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
    # for IP the best hit has the *largest* inner product
    assert abs(result[0][0].distance - expected) <= gen_inaccuracy(result[0][0].distance)
|
||||
|
||||
def test_search_distance_ip_index_params(self, connect, ip_table, get_index_params):
    '''
    target: search an IP table built with each index type, check distance
    method: compare the returned distance with the inner product
            computed locally with numpy
    expected: the returned distance matches the computed value
    '''
    top_k, nprobe = 2, 1
    vectors, ids = self.init_data(connect, ip_table, nb=2)
    connect.create_index(ip_table, get_index_params)
    logging.getLogger().info(connect.describe_index(ip_table))
    query_vecs = [[0.50] * dim]
    status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
    query = numpy.array(query_vecs[0])
    expected = max(
        numpy.inner(query, numpy.array(v)) for v in vectors[:2]
    )
    assert abs(result[0][0].distance - expected) <= gen_inaccuracy(result[0][0].distance)
|
||||
|
||||
# TODO: enable
# @pytest.mark.repeat(5)
@pytest.mark.timeout(30)
def _test_search_concurrent(self, connect, table):
    # Disabled (leading underscore): searches the same table from 10 threads
    # over one shared connection; every query vector must find itself at
    # distance 0.
    vectors, ids = self.init_data(connect, table)
    thread_num = 10
    nb = 100
    top_k = 10
    threads = []
    query_vecs = vectors[nb//2:nb]
    def search():
        # NOTE(review): nprobe is missing from this call compared with the
        # other search tests — confirm the client signature before enabling.
        status, result = connect.search_vectors(table, top_k, query_vecs)
        assert len(result) == len(query_vecs)
        for i in range(len(query_vecs)):
            assert result[i][0].id in ids
            assert result[i][0].distance == 0.0
    for i in range(thread_num):
        x = threading.Thread(target=search, args=())
        threads.append(x)
        x.start()
    for th in threads:
        th.join()
|
||||
|
||||
# TODO: enable
@pytest.mark.timeout(30)
def _test_search_concurrent_multiprocessing(self, args):
    '''
    target: test concurrent search with multiprocesses
    method: search with 4 processes, each process uses an independent connection
    expected: status ok and the returned vectors should be query_records
    '''
    nb = 100
    top_k = 10
    process_num = 4
    processes = []
    table = gen_unique_str("test_search_concurrent_multiprocessing")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table,
             'dimension': dim,
             'index_type': IndexType.FLAT,
             'store_raw_vector': False}
    # create table
    milvus = Milvus()
    milvus.connect(uri=uri)
    milvus.create_table(param)
    vectors, ids = self.init_data(milvus, table, nb=nb)
    query_vecs = vectors[nb//2:nb]
    def search(milvus):
        # NOTE(review): nprobe is omitted here, unlike the other search
        # calls — confirm the client signature before enabling.  Also note
        # that assertion failures inside a child process do not fail pytest;
        # check exit codes when enabling this test.
        status, result = milvus.search_vectors(table, top_k, query_vecs)
        assert len(result) == len(query_vecs)
        for i in range(len(query_vecs)):
            assert result[i][0].id in ids
            assert result[i][0].distance == 0.0

    for i in range(process_num):
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=search, args=(milvus, ))
        processes.append(p)
        p.start()
        time.sleep(0.2)
    for p in processes:
        p.join()
|
||||
|
||||
def test_search_multi_table_L2(search, args):
    '''
    target: test search over multiple L2 tables
    method: add the shared vectors into 10 tables, then search each table
    expected: search status ok, expected ids found in each result set
    '''
    num = 10
    top_k = 10
    nprobe = 1
    tables = []
    idx = []
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    for i in range(num):
        table = gen_unique_str("test_add_multitable_%d" % i)
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': MetricType.L2}
        milvus = Milvus()
        milvus.connect(uri=uri)
        milvus.create_table(param)
        status, ids = milvus.add_vectors(table, vectors)
        assert status.OK()
        assert len(ids) == len(vectors)
        tables.append(table)
        # remember one id from the head, middle and tail of this batch
        idx.extend((ids[0], ids[10], ids[20]))
    time.sleep(6)
    query_vecs = [vectors[0], vectors[10], vectors[20]]
    # query every table in turn with the last client created above
    for i, table in enumerate(tables):
        status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for j in range(len(query_vecs)):
            assert len(result[j]) == top_k
        for j in range(len(query_vecs)):
            assert check_result(result[j], idx[3 * i + j])
|
||||
|
||||
def test_search_multi_table_IP(search, args):
    '''
    target: test search over multiple IP tables
    method: add the shared vectors into 10 tables, then search each table
    expected: search status ok, expected ids found in each result set
    '''
    num = 10
    top_k = 10
    nprobe = 1
    tables = []
    idx = []
    for i in range(num):
        table = gen_unique_str("test_add_multitable_%d" % i)
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 # BUG FIX: this IP test previously created tables with
                 # MetricType.L2, duplicating the L2 test instead of
                 # exercising the IP metric
                 'metric_type': MetricType.IP}
        # create table
        milvus = Milvus()
        milvus.connect(uri=uri)
        milvus.create_table(param)
        status, ids = milvus.add_vectors(table, vectors)
        assert status.OK()
        assert len(ids) == len(vectors)
        tables.append(table)
        # remember one id from the head, middle and tail of this batch
        idx.append(ids[0])
        idx.append(ids[10])
        idx.append(ids[20])
    time.sleep(6)
    query_vecs = [vectors[0], vectors[10], vectors[20]]
    # query every table in turn with the last client created above
    for i in range(num):
        table = tables[i]
        status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for j in range(len(query_vecs)):
            assert len(result[j]) == top_k
        for j in range(len(query_vecs)):
            assert check_result(result[j], idx[3 * i + j])
|
||||
"""
|
||||
******************************************************************
|
||||
# The following cases are used to test `search_vectors` function
|
||||
# with invalid table_name top-k / nprobe / query_range
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
class TestSearchParamsInvalid(object):
    # Negative tests for `search_vectors`: invalid table names, top-k,
    # nprobe and query ranges must be rejected and must not break the
    # client connection.
    index_params = random.choice(gen_index_params())
    logging.getLogger().info(index_params)

    def init_data(self, connect, table, nb=100):
        '''
        Generate vectors and add them to the table before searching.
        Reuses the module-level `vectors` for the default nb, otherwise
        generates a fresh batch; waits for the insert to become searchable.
        '''
        global vectors
        if nb == 100:
            add_vectors = vectors
        else:
            add_vectors = gen_vectors(nb, dim)
        status, ids = connect.add_vectors(table, add_vectors)
        sleep(add_interval_time)
        return add_vectors, ids

    """
    Test search table with invalid table names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_table_names()
    )
    def get_table_name(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_tablename(self, connect, get_table_name):
        # searching a malformed table name must return a bad status
        table_name = get_table_name
        logging.getLogger().info(table_name)
        top_k = 1
        nprobe = 1
        query_vecs = gen_vectors(1, dim)
        status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
        assert not status.OK()

    """
    Test search table with invalid top-k
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_top_ks()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_top_k(self, connect, table, get_top_k):
        '''
        target: test search function with a wrong top_k
        method: search with an invalid top_k
        expected: raise an error, and the connection is still usable
        '''
        top_k = get_top_k
        logging.getLogger().info(top_k)
        nprobe = 1
        query_vecs = gen_vectors(1, dim)
        with pytest.raises(Exception) as e:
            status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
        # the connection must survive the failed call
        res = connect.server_version()

    @pytest.mark.level(2)
    def test_search_with_invalid_top_k_ip(self, connect, ip_table, get_top_k):
        '''
        target: test search function with a wrong top_k (IP table)
        method: search with an invalid top_k
        expected: raise an error, and the connection is still usable
        '''
        top_k = get_top_k
        logging.getLogger().info(top_k)
        nprobe = 1
        query_vecs = gen_vectors(1, dim)
        with pytest.raises(Exception) as e:
            status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
        # the connection must survive the failed call
        res = connect.server_version()

    """
    Test search table with invalid nprobe
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_nprobes()
    )
    def get_nprobes(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_nrpobe(self, connect, table, get_nprobes):
        '''
        target: test search function with a wrong nprobe
        method: search with an invalid nprobe
        expected: raise an error, and the connection is still usable
        '''
        top_k = 1
        nprobe = get_nprobes
        logging.getLogger().info(nprobe)
        query_vecs = gen_vectors(1, dim)
        # positive-but-invalid nprobes are rejected server-side with a bad
        # status; other types are rejected client-side with an exception
        if isinstance(nprobe, int) and nprobe > 0:
            status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
            assert not status.OK()
        else:
            with pytest.raises(Exception) as e:
                status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)

    @pytest.mark.level(2)
    def test_search_with_invalid_nrpobe_ip(self, connect, ip_table, get_nprobes):
        '''
        target: test search function with a wrong nprobe (IP table)
        method: search with an invalid nprobe
        expected: raise an error, and the connection is still usable
        '''
        top_k = 1
        nprobe = get_nprobes
        logging.getLogger().info(nprobe)
        query_vecs = gen_vectors(1, dim)
        # same split as the L2 variant above
        if isinstance(nprobe, int) and nprobe > 0:
            status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
            assert not status.OK()
        else:
            with pytest.raises(Exception) as e:
                status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)

    """
    Test search table with invalid query ranges
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_query_ranges()
    )
    def get_query_ranges(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_flat_with_invalid_query_range(self, connect, table, get_query_ranges):
        '''
        target: test search function with a wrong query_range
        method: search with an invalid query_range
        expected: raise an error, and the connection is still usable
        '''
        top_k = 1
        nprobe = 1
        query_vecs = [vectors[0]]
        query_ranges = get_query_ranges
        logging.getLogger().info(query_ranges)
        with pytest.raises(Exception) as e:
            # NOTE(review): top_k is passed as the literal 1 here, leaving
            # the local `top_k` unused — confirm whether this is intentional
            status, result = connect.search_vectors(table, 1, nprobe, query_vecs, query_ranges=query_ranges)


    @pytest.mark.level(2)
    def test_search_flat_with_invalid_query_range_ip(self, connect, ip_table, get_query_ranges):
        '''
        target: test search function with a wrong query_range (IP table)
        method: search with an invalid query_range
        expected: raise an error, and the connection is still usable
        '''
        top_k = 1
        nprobe = 1
        query_vecs = [vectors[0]]
        query_ranges = get_query_ranges
        logging.getLogger().info(query_ranges)
        with pytest.raises(Exception) as e:
            status, result = connect.search_vectors(ip_table, 1, nprobe, query_vecs, query_ranges=query_ranges)
|
||||
def check_result(result, id):
    """Return True if *id* appears among the first five hits of *result*.

    Only the top five ids are considered when the result has five or more
    hits; shorter results are scanned in full.
    """
    top_hits = result[:5] if len(result) >= 5 else result
    return any(hit.id == id for hit in top_hits)
|
|
@ -0,0 +1,883 @@
|
|||
import random
|
||||
import pdb
|
||||
import pytest
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
import numpy
|
||||
from milvus import Milvus
|
||||
from milvus import IndexType, MetricType
|
||||
from utils import *
|
||||
|
||||
dim = 128                        # vector dimensionality used by every test
delete_table_interval_time = 3   # seconds to wait after dropping a table
index_file_size = 10             # index_file_size passed when creating tables
vectors = gen_vectors(100, dim)  # shared fixture data inserted into tables
|
||||
|
||||
|
||||
class TestTable:
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `create_table` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_create_table(self, connect):
    '''
    target: test create normal table
    method: create a table with correct params
    expected: create status return ok
    '''
    name = gen_unique_str("test_table")
    param = {
        'table_name': name,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    }
    assert connect.create_table(param).OK()
|
||||
|
||||
def test_create_table_ip(self, connect):
    '''
    target: test create normal table with the IP metric
    method: create a table with correct params
    expected: create status return ok
    '''
    name = gen_unique_str("test_table")
    param = {
        'table_name': name,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.IP,
    }
    assert connect.create_table(param).OK()
|
||||
|
||||
@pytest.mark.level(2)
def test_create_table_without_connection(self, dis_connect):
    '''
    target: test create table, without connection
    method: create table with correct params on a disconnected instance
    expected: create raises an exception
    '''
    param = {
        'table_name': gen_unique_str("test_table"),
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    }
    with pytest.raises(Exception) as e:
        dis_connect.create_table(param)
|
||||
|
||||
def test_create_table_existed(self, connect):
    '''
    target: test create table when the table name already exists
    method: create a table twice with the same table_name
    expected: first create ok, second create status not ok
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    status = connect.create_table(param)
    # previously unchecked: the first create must succeed, otherwise the
    # duplicate check below could pass for the wrong reason
    assert status.OK()
    status = connect.create_table(param)
    assert not status.OK()
|
||||
|
||||
@pytest.mark.level(2)
def test_create_table_existed_ip(self, connect):
    '''
    target: test create table when the table name already exists (IP metric)
    method: create a table twice with the same table_name
    expected: first create ok, second create status not ok
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.IP}
    status = connect.create_table(param)
    # previously unchecked: the first create must succeed, otherwise the
    # duplicate check below could pass for the wrong reason
    assert status.OK()
    status = connect.create_table(param)
    assert not status.OK()
|
||||
|
||||
def test_create_table_None(self, connect):
    '''
    target: test create table with table name None
    method: create table, param table_name is None
    expected: create raises an error
    '''
    param = {
        'table_name': None,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    }
    with pytest.raises(Exception) as e:
        connect.create_table(param)
|
||||
|
||||
def test_create_table_no_dimension(self, connect):
    '''
    target: test create table with the dimension param missing
    method: create table without 'dimension'
    expected: create raises an error
    '''
    param = {
        'table_name': gen_unique_str("test_table"),
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    }
    with pytest.raises(Exception) as e:
        connect.create_table(param)
|
||||
|
||||
def test_create_table_no_file_size(self, connect):
    '''
    target: test create table with no index_file_size param
    method: create table without 'index_file_size'
    expected: create status ok, server default of 1024 is used
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'metric_type': MetricType.L2}
    status = connect.create_table(param)
    logging.getLogger().info(status)
    # previously unchecked: create must succeed before describing
    assert status.OK()
    status, result = connect.describe_table(table_name)
    logging.getLogger().info(result)
    assert status.OK()
    assert result.index_file_size == 1024
|
||||
|
||||
def test_create_table_no_metric_type(self, connect):
    '''
    target: test create table with no metric_type param
    method: create table without 'metric_type'
    expected: create status ok, server default of L2 is used
    '''
    table_name = gen_unique_str("test_table")
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size}
    status = connect.create_table(param)
    # previously unchecked: create must succeed before describing
    assert status.OK()
    status, result = connect.describe_table(table_name)
    logging.getLogger().info(result)
    assert status.OK()
    assert result.metric_type == MetricType.L2
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `describe_table` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_table_describe_result(self, connect):
    '''
    target: test describe table created with correct params
    method: create a table, then check the values returned by describe
    expected: table_name and metric_type match what was created
    '''
    table_name = gen_unique_str("test_table")
    connect.create_table({
        'table_name': table_name,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    })
    status, res = connect.describe_table(table_name)
    assert res.table_name == table_name
    assert res.metric_type == MetricType.L2
|
||||
|
||||
def test_table_describe_table_name_ip(self, connect):
    '''
    target: test describe table created with correct params (IP metric)
    method: create a table, then check the values returned by describe
    expected: table_name and metric_type match what was created
    '''
    table_name = gen_unique_str("test_table")
    connect.create_table({
        'table_name': table_name,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.IP,
    })
    status, res = connect.describe_table(table_name)
    assert res.table_name == table_name
    assert res.metric_type == MetricType.IP
|
||||
|
||||
# TODO: enable
@pytest.mark.level(2)
def _test_table_describe_table_name_multiprocessing(self, connect, args):
    '''
    target: test describe table with multiple processes
    method: create a table, then describe it from several worker processes
    expected: table_name equals the table name created
    '''
    table_name = gen_unique_str("test_table")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table_name,
             'dimension': dim,
             'index_file_size': index_file_size,
             'metric_type': MetricType.L2}
    connect.create_table(param)

    def describetable(milvus):
        # NOTE(review): assertion failures inside a child process do not
        # fail pytest — check process exit codes before enabling this test.
        status, res = milvus.describe_table(table_name)
        assert res.table_name == table_name

    process_num = 4
    processes = []
    for i in range(process_num):
        # each worker gets its own client connection
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=describetable, args=(milvus,))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
|
||||
|
||||
@pytest.mark.level(2)
def test_table_describe_without_connection(self, table, dis_connect):
    '''
    target: test describe table, without connection
    method: describe table with correct params on a disconnected instance
    expected: describe raises an exception
    '''
    with pytest.raises(Exception) as e:
        dis_connect.describe_table(table)
|
||||
|
||||
def test_table_describe_dimension(self, connect):
    '''
    target: test describe table created with correct params
    method: create a table with a non-default dimension, then describe it
    expected: reported dimension equals the dimension used at creation
    '''
    table_name = gen_unique_str("test_table")
    dimension = dim + 1
    connect.create_table({
        'table_name': table_name,
        'dimension': dimension,
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    })
    status, res = connect.describe_table(table_name)
    assert res.dimension == dimension
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `delete_table` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_delete_table(self, connect, table):
    '''
    target: test delete table created with correct params
    method: delete the fixture table, then check it is gone
    expected: the table no longer exists
    '''
    status = connect.delete_table(table)
    # the table must not be visible after deletion
    assert not connect.has_table(table)
|
||||
|
||||
def test_delete_table_ip(self, connect, ip_table):
    '''
    target: test delete table created with correct params (IP metric)
    method: delete the fixture table, then check it is gone
    expected: the table no longer exists
    '''
    status = connect.delete_table(ip_table)
    # the table must not be visible after deletion
    assert not connect.has_table(ip_table)
|
||||
|
||||
@pytest.mark.level(2)
def test_table_delete_without_connection(self, table, dis_connect):
    '''
    target: test delete table, without connection
    method: delete table with correct params on a disconnected instance
    expected: delete raises an exception
    '''
    with pytest.raises(Exception) as e:
        dis_connect.delete_table(table)
|
||||
|
||||
def test_delete_table_not_existed(self, connect):
    '''
    target: test delete a table that was never created
    method: delete a fresh, unused table name
    expected: status not ok
    '''
    missing_name = gen_unique_str("test_table")
    status = connect.delete_table(missing_name)
    # deleting a non-existent table must report a non-zero status code
    assert status.code != 0
|
||||
|
||||
def test_delete_table_repeatedly(self, connect):
    '''
    target: test delete table created with correct params
    method: create and delete new tables repeatedly,
            checking each one is gone afterwards
    expected: create ok and delete ok
    '''
    loops = 1
    for i in range(loops):
        table_name = gen_unique_str("test_table")
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        connect.create_table(param)
        status = connect.delete_table(table_name)
        # BUG FIX: this module imports only `from time import sleep`, so the
        # original `time.sleep(1)` would raise NameError (unless `utils`
        # happens to re-export `time`); use the imported name directly
        sleep(1)
        assert not connect.has_table(table_name)
|
||||
|
||||
def test_delete_create_table_repeatedly(self, connect):
    '''
    target: test delete and create the same table repeatedly
    method: create and delete a table with a fixed name in a loop
    expected: create ok and delete ok on every iteration
    '''
    loops = 5
    for i in range(loops):
        table_name = "test_table"
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        connect.create_table(param)
        status = connect.delete_table(table_name)
        # BUG FIX: this module imports only `from time import sleep`, so the
        # original `time.sleep(2)` would raise NameError (unless `utils`
        # happens to re-export `time`); use the imported name directly
        sleep(2)
        assert status.OK()
|
||||
|
||||
@pytest.mark.level(2)
def test_delete_create_table_repeatedly_ip(self, connect):
    '''
    target: test delete and create the same table repeatedly (IP metric)
    method: create and delete a table with a fixed name in a loop
    expected: create ok and delete ok on every iteration
    '''
    loops = 5
    for i in range(loops):
        table_name = "test_table"
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.IP}
        connect.create_table(param)
        status = connect.delete_table(table_name)
        # BUG FIX: this module imports only `from time import sleep`, so the
        # original `time.sleep(2)` would raise NameError (unless `utils`
        # happens to re-export `time`); use the imported name directly
        sleep(2)
        assert status.OK()
|
||||
|
||||
# TODO: enable
@pytest.mark.level(2)
def _test_delete_table_multiprocessing(self, args):
    '''
    target: test delete table with multiple processes
    method: delete the same table from several worker processes
    expected: status ok, and the table no longer exists
    '''
    process_num = 6
    processes = []
    uri = "tcp://%s:%s" % (args["ip"], args["port"])

    def deletetable(milvus):
        # NOTE(review): `table` is not defined anywhere in this (disabled)
        # test — it needs a table fixture or name before it can be enabled.
        status = milvus.delete_table(table)
        # assert not status.code==0
        # NOTE(review): asserting has_table right after delete_table looks
        # inverted (did the author mean `assert not ...`?) — verify intent.
        # Also note assertion failures inside a child process do not fail
        # pytest; check process exit codes when enabling.
        assert milvus.has_table(table)
        assert status.OK()

    for i in range(process_num):
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=deletetable, args=(milvus,))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
|
||||
|
||||
# TODO: enable
@pytest.mark.level(2)
def _test_delete_table_multiprocessing_multitable(self, connect):
    '''
    target: test delete table with multiple processes
    method: create process_num * loop_num tables; each worker process
            deletes its own contiguous slice of loop_num tables
    expected: every delete returns ok and the table disappears
    '''
    process_num = 5
    loop_num = 2
    processes = []

    table = []
    j = 0
    while j < (process_num*loop_num):
        table_name = gen_unique_str("test_delete_table_with_multiprocessing")
        table.append(table_name)
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        connect.create_table(param)
        j = j + 1

    def delete(connect, ids):
        # worker `ids` owns tables [ids*loop_num, ids*loop_num + loop_num)
        # BUG FIX: the stride must be loop_num, not process_num — the old
        # `table[ids*process_num+i]` indexed past the end of `table`
        # (e.g. ids=4, process_num=5 -> index 21 of a 10-element list)
        i = 0
        while i < loop_num:
            status = connect.delete_table(table[ids*loop_num+i])
            # BUG FIX: `time` is not imported (only `from time import
            # sleep`), so use the imported name
            sleep(2)
            assert status.OK()
            assert not connect.has_table(table[ids*loop_num+i])
            i = i + 1

    for i in range(process_num):
        ids = i
        p = Process(target=delete, args=(connect, ids))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `has_table` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_has_table(self, connect):
    '''
    target: test that a created table is reported as existing
    method: create a table, then call has_table
    expected: True
    '''
    table_name = gen_unique_str("test_table")
    connect.create_table({
        'table_name': table_name,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.L2,
    })
    assert connect.has_table(table_name)
|
||||
|
||||
def test_has_table_ip(self, connect):
    '''
    target: test that a created table is reported as existing (IP metric)
    method: create a table, then call has_table
    expected: True
    '''
    table_name = gen_unique_str("test_table")
    connect.create_table({
        'table_name': table_name,
        'dimension': dim,
        'index_file_size': index_file_size,
        'metric_type': MetricType.IP,
    })
    assert connect.has_table(table_name)
|
||||
|
||||
@pytest.mark.level(2)
def test_has_table_without_connection(self, table, dis_connect):
    '''
    target: test has_table, without connection
    method: call has_table with correct params on a disconnected instance
    expected: has_table raises an exception
    '''
    with pytest.raises(Exception) as e:
        dis_connect.has_table(table)
|
||||
|
||||
def test_has_table_not_existed(self, connect):
|
||||
'''
|
||||
target: test if table not created
|
||||
method: random a table name, which not existed in db,
|
||||
assert the value returned by has_table method
|
||||
expected: False
|
||||
'''
|
||||
table_name = gen_unique_str("test_table")
|
||||
assert not connect.has_table(table_name)
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `show_tables` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_show_tables(self, connect):
|
||||
'''
|
||||
target: test show tables is correct or not, if table created
|
||||
method: create table, assert the value returned by show_tables method is equal to 0
|
||||
expected: table_name in show tables
|
||||
'''
|
||||
table_name = gen_unique_str("test_table")
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.L2}
|
||||
connect.create_table(param)
|
||||
status, result = connect.show_tables()
|
||||
assert status.OK()
|
||||
assert table_name in result
|
||||
|
||||
def test_show_tables_ip(self, connect):
|
||||
'''
|
||||
target: test show tables is correct or not, if table created
|
||||
method: create table, assert the value returned by show_tables method is equal to 0
|
||||
expected: table_name in show tables
|
||||
'''
|
||||
table_name = gen_unique_str("test_table")
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.IP}
|
||||
connect.create_table(param)
|
||||
status, result = connect.show_tables()
|
||||
assert status.OK()
|
||||
assert table_name in result
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_show_tables_without_connection(self, dis_connect):
|
||||
'''
|
||||
target: test show_tables, without connection
|
||||
method: calling show_tables with correct params, with a disconnected instance
|
||||
expected: show_tables raise exception
|
||||
'''
|
||||
with pytest.raises(Exception) as e:
|
||||
status = dis_connect.show_tables()
|
||||
|
||||
def test_show_tables_no_table(self, connect):
|
||||
'''
|
||||
target: test show tables is correct or not, if no table in db
|
||||
method: delete all tables,
|
||||
assert the value returned by show_tables method is equal to []
|
||||
expected: the status is ok, and the result is equal to []
|
||||
'''
|
||||
status, result = connect.show_tables()
|
||||
if result:
|
||||
for table_name in result:
|
||||
connect.delete_table(table_name)
|
||||
time.sleep(delete_table_interval_time)
|
||||
status, result = connect.show_tables()
|
||||
assert status.OK()
|
||||
assert len(result) == 0
|
||||
|
||||
# TODO: enable
|
||||
@pytest.mark.level(2)
|
||||
def _test_show_tables_multiprocessing(self, connect, args):
|
||||
'''
|
||||
target: test show tables is correct or not with processes
|
||||
method: create table, assert the value returned by show_tables method is equal to 0
|
||||
expected: table_name in show tables
|
||||
'''
|
||||
table_name = gen_unique_str("test_table")
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
param = {'table_name': table_name,
|
||||
'dimension': dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': MetricType.L2}
|
||||
connect.create_table(param)
|
||||
|
||||
def showtables(milvus):
|
||||
status, result = milvus.show_tables()
|
||||
assert status.OK()
|
||||
assert table_name in result
|
||||
|
||||
process_num = 8
|
||||
processes = []
|
||||
|
||||
for i in range(process_num):
|
||||
milvus = Milvus()
|
||||
milvus.connect(uri=uri)
|
||||
p = Process(target=showtables, args=(milvus,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test `preload_table` function
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
"""
|
||||
generate valid create_index params
|
||||
"""
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_index_params()
|
||||
)
|
||||
def get_index_params(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table(self, connect, table, get_index_params):
|
||||
index_params = get_index_params
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
status = connect.preload_table(table)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table_ip(self, connect, ip_table, get_index_params):
|
||||
index_params = get_index_params
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
status = connect.create_index(ip_table, index_params)
|
||||
status = connect.preload_table(ip_table)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table_not_existed(self, connect, table):
|
||||
table_name = gen_unique_str("test_preload_table_not_existed")
|
||||
index_params = random.choice(gen_index_params())
|
||||
status, ids = connect.add_vectors(table, vectors)
|
||||
status = connect.create_index(table, index_params)
|
||||
status = connect.preload_table(table_name)
|
||||
assert not status.OK()
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table_not_existed_ip(self, connect, ip_table):
|
||||
table_name = gen_unique_str("test_preload_table_not_existed")
|
||||
index_params = random.choice(gen_index_params())
|
||||
status, ids = connect.add_vectors(ip_table, vectors)
|
||||
status = connect.create_index(ip_table, index_params)
|
||||
status = connect.preload_table(table_name)
|
||||
assert not status.OK()
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table_no_vectors(self, connect, table):
|
||||
status = connect.preload_table(table)
|
||||
assert status.OK()
|
||||
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table_no_vectors_ip(self, connect, ip_table):
|
||||
status = connect.preload_table(ip_table)
|
||||
assert status.OK()
|
||||
|
||||
# TODO: psutils get memory usage
|
||||
@pytest.mark.level(1)
|
||||
def test_preload_table_memory_usage(self, connect, table):
|
||||
pass
|
||||
|
||||
|
||||
class TestTableInvalid(object):
    """
    Test creating table with invalid table names
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_table_names()
    )
    def get_table_name(self, request):
        yield request.param

    def test_create_table_with_invalid_tablename(self, connect, get_table_name):
        # Every invalid name produced by the fixture must be rejected.
        param = {'table_name': get_table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        assert not connect.create_table(param).OK()

    def test_create_table_with_empty_tablename(self, connect):
        # An empty name is rejected with an exception rather than a status.
        param = {'table_name': '',
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        with pytest.raises(Exception):
            connect.create_table(param)

    def test_preload_table_with_invalid_tablename(self, connect):
        # Preloading with an empty name raises as well.
        with pytest.raises(Exception):
            connect.preload_table('')
|
||||
|
||||
|
||||
class TestCreateTableDimInvalid(object):
    """
    Test creating table with invalid dimension
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_dims()
    )
    def get_dim(self, request):
        yield request.param

    @pytest.mark.timeout(5)
    def test_create_table_with_invalid_dimension(self, connect, get_dim):
        dimension = get_dim
        table = gen_unique_str("test_create_table_with_invalid_dimension")
        param = {'table_name': table,
                 'dimension': dimension,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        if isinstance(dimension, int) and dimension > 0:
            # Well-typed but out-of-range dims come back as a failed status.
            assert not connect.create_table(param).OK()
        else:
            # Wrong-typed dims raise before reaching the server.
            with pytest.raises(Exception):
                connect.create_table(param)
|
||||
|
||||
|
||||
# TODO: max / min index file size
class TestCreateTableIndexSizeInvalid(object):
    """
    Test creating tables with invalid index_file_size
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_file_sizes()
    )
    def get_file_size(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_create_table_with_invalid_file_size(self, connect, table, get_file_size):
        file_size = get_file_size
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': file_size,
                 'metric_type': MetricType.L2}
        if isinstance(file_size, int) and file_size > 0:
            # Well-typed but out-of-range sizes come back as a failed status.
            assert not connect.create_table(param).OK()
        else:
            # Wrong-typed sizes raise before reaching the server.
            with pytest.raises(Exception):
                connect.create_table(param)
|
||||
|
||||
|
||||
class TestCreateMetricTypeInvalid(object):
    """
    Test creating tables with invalid metric_type
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_metric_types()
    )
    def get_metric_type(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_create_table_with_invalid_metric_type(self, connect, table, get_metric_type):
        # Renamed from test_create_table_with_invalid_file_size: the old name
        # was a copy-paste leftover from the index-file-size suite and
        # misdescribed what this test exercises.
        metric_type = get_metric_type
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': metric_type}
        with pytest.raises(Exception):
            connect.create_table(param)
|
||||
|
||||
|
||||
def create_table(connect, **params):
    """Create an L2 table from `params` and return the resulting status."""
    return connect.create_table({'table_name': params["table_name"],
                                 'dimension': params["dimension"],
                                 'index_file_size': index_file_size,
                                 'metric_type': MetricType.L2})
|
||||
|
||||
def search_table(connect, **params):
    """Run search_vectors with the given params; return only the status."""
    status, _ = connect.search_vectors(params["table_name"],
                                       params["top_k"],
                                       params["nprobe"],
                                       params["query_vectors"])
    return status
|
||||
|
||||
def preload_table(connect, **params):
    """Preload the table named in `params`; return the status."""
    return connect.preload_table(params["table_name"])
|
||||
|
||||
def has(connect, **params):
    """Return whatever has_table reports for the table named in `params`."""
    return connect.has_table(params["table_name"])
|
||||
|
||||
def show(connect, **params):
    """Return the status of show_tables, discarding the table list."""
    status, _ = connect.show_tables()
    return status
|
||||
|
||||
def delete(connect, **params):
    """Delete the table named in `params`; return the status."""
    return connect.delete_table(params["table_name"])
|
||||
|
||||
def describe(connect, **params):
    """Describe the table named in `params`; return only the status."""
    status, _ = connect.describe_table(params["table_name"])
    return status
|
||||
|
||||
def rowcount(connect, **params):
    """Fetch the row count for the table named in `params`; return only the status."""
    status, _ = connect.get_table_row_count(params["table_name"])
    return status
|
||||
|
||||
def create_index(connect, **params):
    """Create an index on the table named in `params`; return the status."""
    return connect.create_index(params["table_name"], params["index_params"])
|
||||
|
||||
# Dispatch table for the ordering tests below: op code -> helper.
# 10 creates the table; codes above 10 require it to exist; 30 deletes it.
func_map = {
    # 0:has,
    1: show,
    10: create_table,
    11: describe,
    12: rowcount,
    13: search_table,
    14: preload_table,
    15: create_index,
    30: delete,
}
|
||||
|
||||
def gen_sequence():
    """Yield every permutation of the operation codes in func_map."""
    yield from itertools.permutations(func_map.keys())
|
||||
|
||||
class TestTableLogic(object):
    """Run every permutation of table operations and check that valid
    orderings succeed while invalid ones fail at some step."""

    @pytest.mark.parametrize("logic_seq", gen_sequence())
    @pytest.mark.level(2)
    def test_logic(self, connect, logic_seq):
        if self.is_right(logic_seq):
            self.execute(logic_seq, connect)
        else:
            self.execute_with_error(logic_seq, connect)

    def is_right(self, seq):
        """Return True when the operation sequence can run without error."""
        # BUG FIX: `seq` is a tuple (from itertools.permutations) while
        # sorted() returns a list, so `sorted(seq) == seq` was always False
        # and the fast path never fired.
        if sorted(seq) == list(seq):
            return True

        not_created = True
        has_deleted = False
        for op in seq:
            if op > 10 and not_created:
                # Table-dependent op before the table exists.
                return False
            elif op > 10 and has_deleted:
                # Table-dependent op after the table is gone.
                return False
            elif op == 10:
                not_created = False
            elif op == 30:
                has_deleted = True

        return True

    def execute(self, logic_seq, connect):
        """Every step of a valid sequence must return an OK status."""
        basic_params = self.gen_params()
        for op in logic_seq:
            # logging.getLogger().info(op)
            f = func_map[op]
            status = f(connect, **basic_params)
            assert status.OK()

    def execute_with_error(self, logic_seq, connect):
        """An invalid sequence must fail at some step."""
        basic_params = self.gen_params()

        error_flag = False
        for op in logic_seq:
            f = func_map[op]
            status = f(connect, **basic_params)
            if not status.OK():
                # logging.getLogger().info(op)
                error_flag = True
                break
        assert error_flag == True

    def gen_params(self):
        """Build one parameter dict shared by every helper in func_map."""
        table_name = gen_unique_str("test_table")
        top_k = 1
        vectors = gen_vectors(2, dim)
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_type': IndexType.IVFLAT,
                 'metric_type': MetricType.L2,
                 'nprobe': 1,
                 'top_k': top_k,
                 'index_params': {
                     'index_type': IndexType.IVF_SQ8,
                     'nlist': 16384
                 },
                 'query_vectors': vectors}
        return param
|
|
@ -0,0 +1,296 @@
|
|||
import random
|
||||
import pdb
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
from time import sleep
|
||||
from multiprocessing import Process
|
||||
from milvus import Milvus
|
||||
from utils import *
|
||||
from milvus import IndexType, MetricType
|
||||
|
||||
# Vector dimensionality used by every table created in this module.
dim = 128
# index_file_size value passed to create_table.
index_file_size = 10
# Seconds to wait after add_vectors so inserts are visible to row counts.
add_time_interval = 5
|
||||
|
||||
|
||||
class TestTableCount:
    """
    Tests for `get_table_row_count` against L2 tables.

    params means different nb, the nb value may trigger merge, or not
    """
    @pytest.fixture(
        scope="function",
        params=[
            100,
            5000,
            100000,
        ],
    )
    def add_vectors_nb(self, request):
        yield request.param

    """
    generate valid create_index params
    """
    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    def test_table_rows_count(self, connect, table, add_vectors_nb):
        '''
        target: test table rows_count is correct or not
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nb = add_vectors_nb
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=table, records=vectors)
        time.sleep(add_time_interval)
        status, res = connect.get_table_row_count(table)
        assert res == nb

    def test_table_rows_count_after_index_created(self, connect, table, get_index_params):
        '''
        target: test get_table_row_count, after index have been created
        method: add vectors in db, and create index, then calling get_table_row_count with correct params
        expected: count matches the number of vectors added
        '''
        nb = 100
        index_params = get_index_params
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=table, records=vectors)
        time.sleep(add_time_interval)
        # logging.getLogger().info(index_params)
        connect.create_index(table, index_params)
        status, res = connect.get_table_row_count(table)
        assert res == nb

    @pytest.mark.level(2)
    def test_count_without_connection(self, table, dis_connect):
        '''
        target: test get_table_row_count, without connection
        method: calling get_table_row_count with correct params, with a disconnected instance
        expected: get_table_row_count raise exception
        '''
        with pytest.raises(Exception):
            dis_connect.get_table_row_count(table)

    def test_table_rows_count_no_vectors(self, connect, table):
        '''
        target: test table rows_count is correct or not, if table is empty
        method: create table and no vectors in it,
            assert the value returned by get_table_row_count method is equal to 0
        expected: the count is equal to 0
        '''
        table_name = gen_unique_str("test_table")
        # BUG FIX: the original queried the `table` fixture instead of the
        # table it had just created, and omitted metric_type from the params.
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.L2}
        connect.create_table(param)
        status, res = connect.get_table_row_count(table_name)
        assert res == 0

    # TODO: enable
    @pytest.mark.level(2)
    @pytest.mark.timeout(20)
    def _test_table_rows_count_multiprocessing(self, connect, table, args):
        '''
        target: test table rows_count is correct or not with multiprocess
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nq = 2
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        vectors = gen_vectors(nq, dim)
        res = connect.add_vectors(table_name=table, records=vectors)
        time.sleep(add_time_interval)

        def rows_count(milvus):
            status, res = milvus.get_table_row_count(table)
            logging.getLogger().info(status)
            assert res == nq

        process_num = 8
        processes = []
        for i in range(process_num):
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=rows_count, args=(milvus, ))
            processes.append(p)
            p.start()
            logging.getLogger().info(p)
        for p in processes:
            p.join()

    def test_table_rows_count_multi_tables(self, connect):
        '''
        target: test table rows_count is correct or not with multiple tables of L2
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nq = 100
        vectors = gen_vectors(nq, dim)
        table_list = []
        for i in range(50):
            table_name = gen_unique_str('test_table_rows_count_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.L2}
            connect.create_table(param)
            res = connect.add_vectors(table_name=table_name, records=vectors)
        time.sleep(2)
        for i in range(50):
            status, res = connect.get_table_row_count(table_list[i])
            assert status.OK()
            assert res == nq
|
||||
|
||||
|
||||
class TestTableCountIP:
    """
    Tests for `get_table_row_count` against IP (inner-product) tables.

    params means different nb, the nb value may trigger merge, or not
    """

    @pytest.fixture(
        scope="function",
        params=[
            100,
            5000,
            100000,
        ],
    )
    def add_vectors_nb(self, request):
        yield request.param

    """
    generate valid create_index params
    """

    @pytest.fixture(
        scope="function",
        params=gen_index_params()
    )
    def get_index_params(self, request):
        yield request.param

    def test_table_rows_count(self, connect, ip_table, add_vectors_nb):
        '''
        target: test table rows_count is correct or not
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nb = add_vectors_nb
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=ip_table, records=vectors)
        time.sleep(add_time_interval)
        status, res = connect.get_table_row_count(ip_table)
        assert res == nb

    def test_table_rows_count_after_index_created(self, connect, ip_table, get_index_params):
        '''
        target: test get_table_row_count, after index have been created
        method: add vectors in db, and create index, then calling get_table_row_count with correct params
        expected: count matches the number of vectors added
        '''
        nb = 100
        index_params = get_index_params
        vectors = gen_vectors(nb, dim)
        res = connect.add_vectors(table_name=ip_table, records=vectors)
        time.sleep(add_time_interval)
        # logging.getLogger().info(index_params)
        connect.create_index(ip_table, index_params)
        status, res = connect.get_table_row_count(ip_table)
        assert res == nb

    @pytest.mark.level(2)
    def test_count_without_connection(self, ip_table, dis_connect):
        '''
        target: test get_table_row_count, without connection
        method: calling get_table_row_count with correct params, with a disconnected instance
        expected: get_table_row_count raise exception
        '''
        with pytest.raises(Exception):
            dis_connect.get_table_row_count(ip_table)

    def test_table_rows_count_no_vectors(self, connect, ip_table):
        '''
        target: test table rows_count is correct or not, if table is empty
        method: create table and no vectors in it,
            assert the value returned by get_table_row_count method is equal to 0
        expected: the count is equal to 0
        '''
        table_name = gen_unique_str("test_table")
        # BUG FIX: the original queried the `ip_table` fixture instead of the
        # table it had just created, and omitted metric_type from the params.
        param = {'table_name': table_name,
                 'dimension': dim,
                 'index_file_size': index_file_size,
                 'metric_type': MetricType.IP}
        connect.create_table(param)
        status, res = connect.get_table_row_count(table_name)
        assert res == 0

    # TODO: enable
    @pytest.mark.level(2)
    @pytest.mark.timeout(20)
    def _test_table_rows_count_multiprocessing(self, connect, ip_table, args):
        '''
        target: test table rows_count is correct or not with multiprocess
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nq = 2
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        vectors = gen_vectors(nq, dim)
        res = connect.add_vectors(table_name=ip_table, records=vectors)
        time.sleep(add_time_interval)

        def rows_count(milvus):
            status, res = milvus.get_table_row_count(ip_table)
            logging.getLogger().info(status)
            assert res == nq

        process_num = 8
        processes = []
        for i in range(process_num):
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=rows_count, args=(milvus,))
            processes.append(p)
            p.start()
            logging.getLogger().info(p)
        for p in processes:
            p.join()

    def test_table_rows_count_multi_tables(self, connect):
        '''
        target: test table rows_count is correct or not with multiple tables of IP
        method: create table and add vectors in it,
            assert the value returned by get_table_row_count method is equal to length of vectors
        expected: the count is equal to the length of vectors
        '''
        nq = 100
        vectors = gen_vectors(nq, dim)
        table_list = []
        for i in range(50):
            table_name = gen_unique_str('test_table_rows_count_multi_tables')
            table_list.append(table_name)
            param = {'table_name': table_name,
                     'dimension': dim,
                     'index_file_size': index_file_size,
                     'metric_type': MetricType.IP}
            connect.create_table(param)
            res = connect.add_vectors(table_name=table_name, records=vectors)
        time.sleep(2)
        for i in range(50):
            status, res = connect.get_table_row_count(table_list[i])
            assert status.OK()
            assert res == nq
|
|
@ -0,0 +1,545 @@
|
|||
# STL imports
|
||||
import random
|
||||
import string
|
||||
import struct
|
||||
import sys
|
||||
import time, datetime
|
||||
import copy
|
||||
import numpy as np
|
||||
from utils import *
|
||||
from milvus import Milvus, IndexType, MetricType
|
||||
|
||||
|
||||
def gen_inaccuracy(num):
    """Return `num` divided by 255.0 (tolerance used for float comparisons)."""
    return num / 255.0
|
||||
|
||||
def gen_vectors(num, dim):
    """Return `num` random vectors of length `dim`, entries uniform in [0, 1)."""
    vectors = []
    for _ in range(num):
        vectors.append([random.random() for _ in range(dim)])
    return vectors
|
||||
|
||||
|
||||
def gen_single_vector(dim):
    """Return one random vector of length `dim`, wrapped in an outer list."""
    vector = [random.random() for _ in range(dim)]
    return [vector]
|
||||
|
||||
|
||||
def gen_vector(nb, d, seed=np.random.RandomState(1234)):
    """Return an nb x d list of float32 values drawn from `seed`.

    NOTE(review): the default RandomState is created once at import time and
    is shared across calls, so successive default calls continue one stream.
    """
    data = seed.rand(nb, d).astype("float32")
    return data.tolist()
|
||||
|
||||
|
||||
def gen_unique_str(str=None):
    """Return an 8-char random alphanumeric token, optionally prefixed by `str`.

    NOTE(review): the parameter shadows the builtin `str`; name kept so
    keyword callers keep working.
    """
    alphabet = string.ascii_letters + string.digits
    token = "".join(random.choice(alphabet) for _ in range(8))
    return token if str is None else "%s_%s" % (str, token)
|
||||
|
||||
|
||||
def get_current_day():
    """Return today's local date formatted as YYYY-MM-DD."""
    return datetime.datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
|
||||
def get_last_day(day):
    """Return the local date `day` days in the past, as YYYY-MM-DD."""
    target = datetime.datetime.now() - datetime.timedelta(days=day)
    return target.strftime('%Y-%m-%d')
|
||||
|
||||
|
||||
def get_next_day(day):
    """Return the local date `day` days in the future, as YYYY-MM-DD."""
    target = datetime.datetime.now() + datetime.timedelta(days=day)
    return target.strftime('%Y-%m-%d')
|
||||
|
||||
|
||||
def gen_long_str(num):
    """Return a random string of length `num` drawn from the letters of
    'tomorrow'.

    Bug fix: the original built the string but never returned it, so every
    call yielded None; its local also shadowed the imported `string` module.
    """
    return ''.join(random.choice('tomorrow') for _ in range(num))
|
||||
|
||||
|
||||
def gen_invalid_ips():
    """Return IP address strings that a client should reject."""
    return [
        "255.0.0.0", "255.255.0.0", "255.255.255.0", "255.255.255.255",
        "127.0.0", "123.0.0.2",
        "12-s", " ", "12 s", "BB。A", " siede ", "(mn)",
        "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_ports():
    """Return port values that a client should reject."""
    return [
        " ",          # empty-ish
        -1,
        100000,       # too big port
        39540,        # not correct port
        "BB。A", " siede ", "(mn)", "\n", "\t", "中文",
    ]
|
||||
|
||||
|
||||
def gen_invalid_uris():
    """Return malformed connection URIs (blank strings and bad host parts)."""
    port = 19530
    bad_hosts = [" ", "123.0.0.1", "127.0.0", "255.0.0.0", "255.255.0.0",
                 "255.255.255.0", "255.255.255.255", "\n"]

    uris = [" ", "中文"]

    # invalid protocol
    # "tc://%s:%s" % (ip, port),
    # "tcp%s:%s" % (ip, port),

    # # invalid port
    # "tcp://%s:100000" % ip,
    # "tcp://%s: " % ip,
    # "tcp://%s:19540" % ip,
    # "tcp://%s:-1" % ip,
    # "tcp://%s:string" % ip,

    # invalid ip
    uris.extend("tcp://%s:%s" % (host, port) for host in bad_hosts)
    return uris
|
||||
|
||||
|
||||
def gen_invalid_table_names():
    """Return table names that create_table should reject."""
    return [
        "12-s", "12/s", " ",
        # "",
        # None,
        "12 s", "BB。A", "c|c", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_top_ks():
    """Return top_k values that search should reject."""
    return [
        0, -1, None,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_dims():
    """Return dimension values that create_table should reject."""
    return [
        0, -1, 100001, 1000000000000001, None, False,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_file_sizes():
    """Return index_file_size values that create_table should reject."""
    return [
        0, -1, 1000000000000001, None, False,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_index_types():
    """Return index_type values that create_index should reject."""
    return [
        0, -1, 100, 1000000000000001,
        # None,
        False,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_nlists():
    """Return nlist values that create_index should reject."""
    return [
        0, -1, 1000000000000001,
        # None,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
    ]
|
||||
|
||||
|
||||
def gen_invalid_nprobes():
    """Return nprobe values that search should reject."""
    return [
        0, -1, 1000000000000001, None,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
    ]
|
||||
|
||||
|
||||
def gen_invalid_metric_types():
    """Return metric_type values that create_table should reject."""
    return [
        0, -1, 1000000000000001,
        # None,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
    ]
|
||||
|
||||
|
||||
def gen_invalid_vectors():
    """Return vector payloads that add/search should reject."""
    return [
        "1*2", [], [1], [1, 2], [" "], ['a'], [None], None,
        (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        "a".join("a" for i in range(256)),
    ]
|
||||
|
||||
|
||||
def gen_invalid_vector_ids():
    """Return vector-id values that the API should reject (non-int64s)."""
    return [
        1.0, -1.0, None,
        # int 64
        10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
        " ", "", "String", "BB。A", " siede ", "(mn)",
        "#12s", "=c", "\n", "中文",
    ]
|
||||
|
||||
|
||||
|
||||
def gen_invalid_query_ranges():
    """Return query-range arguments that a range query must reject.

    The first six entries pair a valid start date with an empty or
    contradictory end date; the remaining entries pair an invalid start
    value with a valid future end date.
    """
    bad_ends = [
        [(get_last_day(1), "")],
        [(get_current_day(), "")],
        [(get_next_day(1), "")],
        [(get_current_day(), get_last_day(1))],
        [(get_next_day(1), get_last_day(1))],
        [(get_next_day(1), get_current_day())],
    ]
    bad_starts = [
        0, -1, 1, 100001, 1000000000000001, None,
        [1, 2, 3], (1, 2), {"a": 1},
        " ", "", "String", "12-s", "BB。A", " siede ", "(mn)",
        "#12s", "pip+", "=c", "\n", "\t", "中文",
        # 511 'a' characters — same value as "a".join("a" for i in range(256))
        "a" * 511,
    ]
    return bad_ends + [[(start, get_next_day(1))] for start in bad_starts]
|
||||
|
||||
|
||||
def gen_invalid_index_params():
    """Return index_param dicts in which exactly one field is invalid.

    First every invalid index type is paired with a valid nlist, then a
    valid index type (IVFLAT) is paired with every invalid nlist.
    """
    bad_type_params = [{"index_type": index_type, "nlist": 16384}
                       for index_type in gen_invalid_index_types()]
    bad_nlist_params = [{"index_type": IndexType.IVFLAT, "nlist": nlist}
                        for nlist in gen_invalid_nlists()]
    return bad_type_params + bad_nlist_params
|
||||
|
||||
|
||||
def gen_index_params():
    """Return every supported (index_type, nlist) combination for index tests.

    Returns:
        list[dict]: one ``{"index_type": ..., "nlist": ...}`` dict per
        combination — 4 index types x 3 nlist values = 12 entries.
    """
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    nlists = [1, 16384, 50000]
    # Cartesian product via a single comprehension; the original's unused
    # `index_params = []` local, nested helper and backslash continuations
    # (redundant inside brackets) are removed.
    return [{"index_type": index_type, "nlist": nlist}
            for index_type in index_types
            for nlist in nlists]
|
||||
|
||||
def gen_simple_index_params():
    """Return one default-nlist index_param per supported index type.

    Returns:
        list[dict]: one ``{"index_type": ..., "nlist": 16384}`` dict per
        index type — a smaller matrix than :func:`gen_index_params` for
        quick tests.
    """
    index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    nlists = [16384]
    # Same simplification as gen_index_params: drop the unused local, the
    # nested helper and the redundant backslash continuations.
    return [{"index_type": index_type, "nlist": nlist}
            for index_type in index_types
            for nlist in nlists]
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual smoke-test script: loads real query vectors from disk,
    # then creates/populates/indexes/searches a table on a locally running
    # Milvus server. Not part of the automated suite.
    import numpy

    dim = 128        # vector dimensionality of the test data set
    nq = 10000       # number of query vectors to take from the .npy file
    table = "test"

    # NOTE(review): `np` is not bound by the `import numpy` above —
    # presumably the file imports `numpy as np` at the top; confirm
    # before running this script.
    file_name = '/poc/yuncong/ann_1000m/query.npy'
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    # print(vectors)

    connect = Milvus()
    # connect.connect(host="192.168.1.27")
    # print(connect.show_tables())
    # print(connect.get_table_row_count(table))
    # sys.exit()
    connect.connect(host="127.0.0.1")
    # Drop any leftover table from a previous run so row counts start at 0.
    connect.delete_table(table)
    # sys.exit()
    # time.sleep(2)
    print(connect.get_table_row_count(table))
    param = {'table_name': table,
             'dimension': dim,
             'metric_type': MetricType.L2,
             'index_file_size': 10}
    status = connect.create_table(param)
    print(status)
    print(connect.get_table_row_count(table))
    # add vectors
    for i in range(10):
        status, ids = connect.add_vectors(table, vectors)
        print(status)
        print(ids[0])
    # print(ids[0])
    index_params = {"index_type": IndexType.IVFLAT, "nlist": 16384}
    status = connect.create_index(table, index_params)
    print(status)
    # sys.exit()
    # Search with the first inserted vector as the query; twice, to compare
    # a cold and a warm run.
    query_vec = [vectors[0]]
    # print(numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0])))
    top_k = 12
    nprobe = 1
    for i in range(2):
        result = connect.search_vectors(table, top_k, nprobe, query_vec)
        print(result)
    sys.exit()

    # NOTE(review): everything below is unreachable because of sys.exit()
    # above. It is a second, multiprocessing smoke test that relies on
    # `args`, `index_file_size`, `gen_unique_str`, `gen_single_vector`,
    # `Process` and `time` being defined/imported elsewhere in this file.
    table = gen_unique_str("test_add_vector_with_multiprocessing")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table,
             'dimension': dim,
             'index_file_size': index_file_size}
    # create table
    milvus = Milvus()
    milvus.connect(uri=uri)
    milvus.create_table(param)
    vector = gen_single_vector(dim)

    process_num = 4
    loop_num = 10
    processes = []
    # with dependent connection
    def add(milvus):
        # Each worker inserts `vector` loop_num times over its own connection.
        i = 0
        while i < loop_num:
            status, ids = milvus.add_vectors(table, vector)
            i = i + 1
    for i in range(process_num):
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=add, args=(milvus,))
        processes.append(p)
        p.start()
        time.sleep(0.2)
    for p in processes:
        p.join()
    time.sleep(3)
    status, count = milvus.get_table_row_count(table)
    assert count == process_num * loop_num
|
Loading…
Reference in New Issue