[skip ci] add case for load partition and fix connect case and update concurrent search (#4535)

pull/4563/head
ThreadDao 2020-12-29 15:31:19 +08:00 committed by GitHub
parent 14e23a9238
commit e0d33e5546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 230 additions and 228 deletions

View File

@ -116,7 +116,7 @@
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>0.8.2</version>
<version>0.8.6-SNAPSHOT</version>
</dependency>
<!-- <dependency>-->

View File

@ -34,12 +34,11 @@ public class MainClass {
@DataProvider(name="ConnectInstance")
public Object[][] connectInstance() throws ConnectFailedException {
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
.withHost(host)
.withPort(port)
.build();
MilvusClient client = new MilvusGrpcClient(connectParam);
String collectionName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, collectionName}};
}
@ -47,15 +46,15 @@ public class MainClass {
@DataProvider(name="DisConnectInstance")
public Object[][] disConnectInstance() throws ConnectFailedException {
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
.withHost(host)
.withPort(port)
.build();
MilvusClient client = new MilvusGrpcClient(connectParam);
try {
client.disconnect();
} catch (InterruptedException e) {
client.close();
} catch (Exception e) {
e.printStackTrace();
}
String collectionName = RandomStringUtils.randomAlphabetic(10);
@ -69,12 +68,11 @@ public class MainClass {
for (int i = 0; i < metricTypes.length; ++i) {
String collectionName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
.withHost(host)
.withPort(port)
.build();
MilvusClient client = new MilvusGrpcClient(connectParam);
// List<String> tableNames = client.listCollections().getCollectionNames();
// for (int j = 0; j < tableNames.size(); ++j
// ) {
@ -82,9 +80,9 @@ public class MainClass {
// }
// Thread.currentThread().sleep(2000);
CollectionMapping cm = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(metricTypes[i])
.build();
.withIndexFileSize(index_file_size)
.withMetricType(metricTypes[i])
.build();
Response res = client.createCollection(cm);
if (!res.ok()) {
System.out.println(res.getMessage());
@ -102,12 +100,11 @@ public class MainClass {
for (int i = 0; i < metricTypes.length; ++i) {
String collectionName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
.withHost(host)
.withPort(port)
.build();
MilvusClient client = new MilvusGrpcClient(connectParam);
// List<String> tableNames = client.listCollections().getCollectionNames();
// for (int j = 0; j < tableNames.size(); ++j
// ) {
@ -115,9 +112,9 @@ public class MainClass {
// }
// Thread.currentThread().sleep(2000);
CollectionMapping cm = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(metricTypes[i])
.build();
.withIndexFileSize(index_file_size)
.withMetricType(metricTypes[i])
.build();
Response res = client.createCollection(cm);
if (!res.ok()) {
System.out.println(res.getMessage());
@ -182,4 +179,4 @@ public class MainClass {
}
}
}

View File

@ -173,4 +173,45 @@ public class TestAddVectors {
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
// test load collection
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_load_partition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
client.insert(insertParam);
List<String> tags = new ArrayList<>();
tags.add(tag);
Response load_res = client.loadCollection(collectionName, tags);
assert load_res.ok();
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_load_not_existed_partition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
List<String> tags = new ArrayList<>();
tags.add(tag);
Response load_res = client.loadCollection(collectionName, tags);
assert !load_res.ok();
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_load_empty_partitions(MilvusClient client, String collectionName) {
List<String> tags = new ArrayList<>();
Response load_res = client.loadCollection(collectionName, tags);
assert load_res.ok();
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_load_partitions_after_insert(MilvusClient client, String collectionName) {
List<String> tags = new ArrayList<>();
String tag = RandomStringUtils.randomAlphabetic(10);
tags.add(tag);
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
client.insert(insertParam);
Response load_res = client.loadCollection(collectionName, tags);
assert load_res.ok();
}
}

View File

@ -1,7 +1,8 @@
package com;
import io.milvus.client.*;
import org.testng.Assert;
import io.milvus.client.exception.InitializationException;
import java.util.concurrent.TimeUnit;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
@ -9,80 +10,55 @@ public class TestConnect {
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_connect(String host, int port) throws ConnectFailedException {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
Response res = client.connect(connectParam);
assert(res.ok());
assert(client.isConnected());
.withHost(host)
.withPort(port)
.build();
new MilvusGrpcClient(connectParam);
}
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_connect_repeat(String host, int port) {
MilvusGrpcClient client = new MilvusGrpcClient();
Response res = null;
try {
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
res = client.connect(connectParam);
res = client.connect(connectParam);
} catch (ConnectFailedException e) {
e.printStackTrace();
}
assert (res.ok());
assert(client.isConnected());
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
new MilvusGrpcClient(connectParam);
new MilvusGrpcClient(connectParam);
}
@Test(dataProvider="InvalidConnectArgs")
@Test(dataProvider="InvalidConnectArgs", expectedExceptions = {InitializationException.class, IllegalArgumentException.class})
public void test_connect_invalid_connect_args(String ip, int port) {
MilvusClient client = new MilvusGrpcClient();
Response res = null;
try {
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(ip)
.withPort(port)
.build();
res = client.connect(connectParam);
} catch (Exception e) {
e.printStackTrace();
}
Assert.assertEquals(res, null);
assert(!client.isConnected());
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(ip)
.withPort(port)
.withIdleTimeout(30, TimeUnit.SECONDS)
.build();
new MilvusGrpcClient(connectParam);
}
@DataProvider(name="InvalidConnectArgs")
public Object[][] generate_invalid_connect_args() {
int port = 19530;
return new Object[][]{
{"1.1.1.1", port},
{"255.255.0.0", port},
{"1.2.2", port},
{"中文", port},
{"www.baidu.com", 100000},
{"127.0.0.1", 100000},
{"www.baidu.com", 80},
{"1.1.1.1", port},
{"255.255.0.0", port},
{"1.2.2", port},
{"中文", port},
{"www.baidu.com", 100000},
{"127.0.0.1", 100000},
{"www.baidu.com", 80},
};
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_disconnect(MilvusClient client, String collectionName){
assert(!client.isConnected());
client.close();
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_disconnect_repeatably(MilvusClient client, String collectionName){
Response res = null;
try {
res = client.disconnect();
} catch (InterruptedException e) {
e.printStackTrace();
}
assert(!res.ok());
assert(!client.isConnected());
client.close();
client.close();
}
}
}

View File

@ -60,22 +60,12 @@ public class TestMix {
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
try {
client.connect(connectParam);
} catch (ConnectFailedException e) {
e.printStackTrace();
}
assert(client.isConnected());
try {
client.disconnect();
} catch (InterruptedException e) {
e.printStackTrace();
}
MilvusClient client = new MilvusGrpcClient(connectParam);
client.close();
});
}
executor.awaitQuiescence(100, TimeUnit.SECONDS);
@ -193,17 +183,11 @@ public class TestMix {
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
try {
client.connect(connectParam);
} catch (ConnectFailedException e) {
e.printStackTrace();
}
assert(client.isConnected());
MilvusClient client = new MilvusGrpcClient(connectParam);
String collectionName = RandomStringUtils.randomAlphabetic(10);
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
@ -214,11 +198,7 @@ public class TestMix {
client.insert(insertParam);
Response response = client.dropCollection(collectionName);
Assert.assertTrue(response.ok());
try {
client.disconnect();
} catch (InterruptedException e) {
e.printStackTrace();
}
client.close();
});
}
executor.awaitQuiescence(100, TimeUnit.SECONDS);

View File

@ -1,133 +1,142 @@
package com;
import io.milvus.client.*;
import org.apache.commons.cli.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
public class TestPS {
private static int dimension = 512;
private static String host = "192.168.1.112";
private static String port = "19532";
private static int dimension = 512;
private static String host = "localhost";
private static String port = "19530";
public static void setHost(String host) {
TestPS.host = host;
public static void setHost(String host) {
TestPS.host = host;
}
public static void setPort(String port) {
TestPS.port = port;
}
public static void main(String[] args) {
int nb = 10000;
int nq = 1;
int top_k = 2;
int loops = 100000;
int index_file_size = 1024;
String collectionName = "random_1m_2048_512_ip_sq8";
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
CommandLineParser parser = new DefaultParser();
Options options = new Options();
options.addOption("h", "host", true, "milvus-server hostname/ip");
options.addOption("p", "port", true, "milvus-server port");
try {
CommandLine cmd = parser.parse(options, args);
String host = cmd.getOptionValue("host");
if (host != null) {
setHost(host);
}
String port = cmd.getOptionValue("port");
if (port != null) {
setPort(port);
}
System.out.println("Host: " + host + ", Port: " + port);
} catch (ParseException exp) {
System.err.println("Parsing failed. Reason: " + exp.getMessage());
}
public static void setPort(String port) {
TestPS.port = port;
ConnectParam connectParam =
new ConnectParam.Builder().withHost(host).withPort(Integer.parseInt(port)).build();
MilvusClient client = new MilvusGrpcClient(connectParam);
if (client.hasCollection(collectionName).hasCollection()) {
client.dropCollection(collectionName);
}
public static void main(String[] args) throws ConnectFailedException {
int nb = 10000;
int nq = 1;
int nprobe = 1024;
int top_k = 2;
int loops = 100000000;
// int index_file_size = 1024;
String collectionName = "random_1m_2048_512_ip_sq8";
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
CommandLineParser parser = new DefaultParser();
Options options = new Options();
options.addOption("h", "host", true, "milvus-server hostname/ip");
options.addOption("p", "port", true, "milvus-server port");
try {
CommandLine cmd = parser.parse(options, args);
String host = cmd.getOptionValue("host");
if (host != null) {
setHost(host);
}
String port = cmd.getOptionValue("port");
if (port != null) {
setPort(port);
}
System.out.println("Host: "+host+", Port: "+port);
}
catch(ParseException exp) {
System.err.println("Parsing failed. Reason: " + exp.getMessage() );
}
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(Integer.parseInt(port))
.build();
client.connect(connectParam);
// String collectionName = RandomStringUtils.randomAlphabetic(10);
// TableSchema tableSchema = new TableSchema.Builder(collectionName, dimension)
// .withIndexFileSize(index_file_size)
// .withMetricType(MetricType.IP)
// .build();
// Response res = client.createTable(tableSchema);
// List<Long> vectorIds;
// vectorIds = Stream.iterate(0L, n -> n)
// .limit(nb)
// .collect(Collectors.toList());
// InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withVectorIds(vectorIds).build();
System.setProperty("java.util.concurrent.ForkJoinPool.common.parallelism", "50");
ForkJoinPool executor_search = new ForkJoinPool();
// for (int i = 0; i < loops; i++) {
// List<List<Float>> queryVectors = Utils.genVectors(nq, dimension, true);
// executor_search.execute(
// () -> {
//// InsertResponse res_insert = client.insert(insertParam);
//// assert (res_insert.getResponse().ok());
//// System.out.println("In insert");
// String params = "{\"nprobe\":1024}";
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withFloatVectors(queryVectors)
// .withParamsInJson(params)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (res_search.getResponse().ok());
// });
// }
IntStream.range(0, loops).parallel().forEach(index -> {
List<List<Float>> queryVectors = Utils.genVectors(nq, dimension, true);
String params = "{\"nprobe\":1024}";
SearchParam searchParam = new SearchParam.Builder(collectionName)
.withFloatVectors(queryVectors)
.withParamsInJson(params)
.withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
});
executor_search.awaitQuiescence(300, TimeUnit.SECONDS);
executor_search.shutdown();
CountEntitiesResponse getTableRowCountResponse = client.countEntities(collectionName);
System.out.println(getTableRowCountResponse.getCollectionEntityCount());
// int thread_num = 50;
// ForkJoinPool executor = new ForkJoinPool();
// for (int i = 0; i < thread_num; i++) {
// executor.execute(
// () -> {
// String params = "{\"nprobe\":\"1024\"}";
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withFloatVectors(queryVectors)
// .withParamsInJson(params)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (res_search.getResponse().ok());
// });
// }
// executor.awaitQuiescence(100, TimeUnit.SECONDS);
// executor.shutdown();
// CountEntitiesResponse getTableRowCountResponse = client.countEntities(collectionName);
// System.out.println(getTableRowCountResponse.getCollectionEntityCount());
CollectionMapping tableSchema =
new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
client.createCollection(tableSchema);
List<Long> vectorIds;
vectorIds = Stream.iterate(0L, n -> n).limit(nb).collect(Collectors.toList());
InsertParam insertParam =
new InsertParam.Builder(collectionName)
.withFloatVectors(vectors)
.withVectorIds(vectorIds)
.build();
for (int i = 0; i < 100; ++i) {
InsertResponse res_insert = client.insert(insertParam);
assert (res_insert.getResponse().ok());
}
}
System.out.println(client.countEntities(collectionName).getCollectionEntityCount());
ExecutorService executors = Executors.newFixedThreadPool(50);
List<Future> resultList = new ArrayList<Future>();
for (int i = 0; i < loops; i++) {
List<List<Float>> queryVectors = Utils.genVectors(nq, dimension, true);
Future future =
executors.submit(
() -> {
String params = "{\"nprobe\":1024}";
SearchParam searchParam =
new SearchParam.Builder(collectionName)
.withFloatVectors(queryVectors)
.withParamsInJson(params)
.withTopK(top_k)
.build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
});
resultList.add(future);
}
// IntStream.range(0, loops).parallel().forEach(index -> {
// List<List<Float>> queryVectors = Utils.genVectors(nq, dimension,
// true);
// String params = "{\"nprobe\":1024}";
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withFloatVectors(queryVectors)
// .withParamsInJson(params)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (res_search.getResponse().ok());
// });
// executor_search.awaitQuiescence(300, TimeUnit.SECONDS);
executors.shutdown();
CountEntitiesResponse getTableRowCountResponse = client.countEntities(collectionName);
System.out.println(getTableRowCountResponse.getCollectionEntityCount());
for (Future f : resultList) {
try {
System.out.println(f.get());
} catch (Exception e) {
e.printStackTrace();
}
// int thread_num = 50;
// ForkJoinPool executor = new ForkJoinPool();
// for (int i = 0; i < thread_num; i++) {
// executor.execute(
// () -> {
// String params = "{\"nprobe\":\"1024\"}";
// SearchParam searchParam = new SearchParam.Builder(collectionName)
// .withFloatVectors(queryVectors)
// .withParamsInJson(params)
// .withTopK(top_k).build();
// SearchResponse res_search = client.search(searchParam);
// assert (res_search.getResponse().ok());
// });
// }
// executor.awaitQuiescence(100, TimeUnit.SECONDS);
// executor.shutdown();
// CountEntitiesResponse getTableRowCountResponse =
// client.countEntities(collectionName);
// System.out.println(getTableRowCountResponse.getCollectionEntityCount());
}
}
}

View File

@ -7,12 +7,11 @@ public class TestPing {
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_server_status(String host, int port) throws ConnectFailedException {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
MilvusClient client = new MilvusGrpcClient(connectParam);
Response res = client.getServerStatus();
assert (res.ok());
}