Merge pull request #232 from XuPeng-SH/0.6.0

[skip ci] Experimental shards middleware for Milvus
pull/245/head
Jin Hai 2019-11-07 16:10:09 +08:00 committed by GitHub
commit 2756dedd95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
57 changed files with 4271 additions and 4 deletions

3
.gitignore vendored

@ -26,3 +26,6 @@ cmake_build
*.lo
*.tar.gz
*.log
.coverage
*.pyc
cov_html/


@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA.
## Feature
- \#12 - Pure CPU version for Milvus
- \#226 - Experimental shards middleware for Milvus
## Improvement
@ -84,7 +85,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-658 - Fix SQ8 Hybrid can't search
- MS-665 - IVF_SQ8H search crash when no GPU resource in search_resources
- \#9 - Change default gpu_cache_capacity to 4
- \#20 - C++ sdk example get grpc error
- \#23 - Add unittest to improve code coverage
- \#31 - make clang-format failed after run build.sh -l
- \#39 - Create SQ8H index hang if using github server version
@ -136,7 +137,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-635 - Add compile option to support customized faiss
- MS-660 - add ubuntu_build_deps.sh
- \#18 - Add all test cases
# Milvus 0.4.0 (2019-09-12)
## Bug
@ -345,11 +346,11 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-82 - Update server startup welcome message
- MS-83 - Update vecwise to Milvus
- MS-77 - Performance issue of post-search action
- MS-22 - Enhancement for MemVector size control
- MS-92 - Unify behavior of debug and release build
- MS-98 - Install all unit test to installation directory
- MS-115 - Change is_startup of metric_config switch from true to on
- MS-122 - Archive criteria config
- MS-124 - HasTable interface
- MS-126 - Add more error code
- MS-128 - Change default db path

13
shards/.dockerignore Normal file

@ -0,0 +1,13 @@
.git
.gitignore
.env
.coverage
.dockerignore
cov_html/
.pytest_cache
__pycache__
*/__pycache__
*.md
*.yml
*.yaml

10
shards/Dockerfile Normal file

@ -0,0 +1,10 @@
FROM python:3.6
RUN apt update && apt install -y \
less \
telnet
RUN mkdir /source
WORKDIR /source
ADD ./requirements.txt ./
RUN pip install -r requirements.txt
COPY . .
CMD python mishards/main.py

35
shards/Makefile Normal file

@ -0,0 +1,35 @@
HOST=$(or $(host),127.0.0.1)
PORT=$(or $(port),19530)
build:
docker build --network=host -t milvusdb/mishards .
push:
docker push milvusdb/mishards
pull:
docker pull milvusdb/mishards
deploy: clean_deploy
cd all_in_one && docker-compose -f all_in_one.yml up -d && cd -
clean_deploy:
cd all_in_one && docker-compose -f all_in_one.yml down && cd -
probe_deploy:
docker run --rm --name probe --net=host milvusdb/mishards /bin/bash -c "python all_in_one/probe_test.py"
cluster:
cd kubernetes_demo;./start.sh baseup;sleep 10;./start.sh appup;cd -
clean_cluster:
cd kubernetes_demo;./start.sh cleanup;cd -
cluster_status:
kubectl get pods -n milvus -o wide
probe_cluster:
@echo
$(shell kubectl get service -n milvus | grep milvus-proxy-servers | awk {'print $$4,$$5'} | awk -F"[: ]" {'print "docker run --rm --name probe --net=host milvusdb/mishards /bin/bash -c \"python all_in_one/probe_test.py --port="$$2" --host="$$1"\""'})
probe:
docker run --rm --name probe --net=host milvusdb/mishards /bin/bash -c "python all_in_one/probe_test.py --port=${PORT} --host=${HOST}"
clean_coverage:
rm -rf cov_html
clean: clean_coverage clean_deploy clean_cluster
style:
pycodestyle --config=.
coverage:
pytest --cov-report html:cov_html --cov=mishards
test:
pytest
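The `HOST`/`PORT` lines at the top of this Makefile use `$(or ...)` so that `make probe host=... port=...` falls back to sensible defaults when the variables are not given. The equivalent defaulting in plain shell, shown only as an illustration:

```shell
# Plain-shell equivalent of the Makefile's $(or $(host),127.0.0.1) defaulting:
# use $host / $port if set and non-empty, otherwise fall back to the default.
HOST="${host:-127.0.0.1}"
PORT="${port:-19530}"
echo "probe target: ${HOST}:${PORT}"
```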

147
shards/Tutorial_CN.md Normal file

@ -0,0 +1,147 @@
# Mishards User Guide
---
Milvus is built for approximate retrieval and analysis of massive unstructured data. A single Milvus instance can handle billion-scale datasets; for datasets of tens or hundreds of billions of vectors, a Milvus cluster instance is needed. To upstream applications, that cluster instance should be usable exactly like a single-node instance while meeting the low-latency, high-concurrency requirements of massive data. Mishards is this cluster middleware: internally it handles request forwarding, read/write splitting, horizontal scaling, and dynamic capacity expansion, presenting users with a Milvus instance whose memory and compute can be scaled without limit.
## Runtime Environment
---
### Quick Start on a Single Machine
**Requires `python >= 3.4`**
```
1. cd milvus/shards
2. pip install -r requirements.txt
3. nvidia-docker run --rm -d -p 19530:19530 -v /tmp/milvus/db:/opt/milvus/db milvusdb/milvus:0.5.0-d102119-ede20b
4. sudo chown -R $USER:$USER /tmp/milvus
5. cp mishards/.env.example mishards/.env
6. python mishards/main.py   # the .env file configures mishards to listen on port 19532
7. make probe port=19532     # health check
```
### Starting Instances with Containers
`all_in_one` starts two Milvus instances, one mishards instance, and one Jaeger tracing instance on the server.
**Start**
```
cd milvus/shards
1. Install docker-compose
2. make build
3. make deploy          # listens on port 19531
4. make clean_deploy    # tear down the services
5. make probe_deploy    # health check
```
**Open the Jaeger UI**
```
Open "http://127.0.0.1:16686/" in a browser
```
### Quick Start on Kubernetes
**Prerequisites**
```
- A Kubernetes cluster
- nvidia-docker installed
- Shared storage
- kubectl installed and able to access the cluster
```
**Steps**
```
cd milvus/shards
1. make cluster          # start the cluster
2. make probe_cluster    # health check
3. make clean_cluster    # shut down the cluster
```
**Scaling compute instances**
```
cd milvus/shards/kubernetes_demo/
./start.sh scale-ro-server 2   # scale compute instances to 2
```
**Scaling proxy instances**
```
cd milvus/shards/kubernetes_demo/
./start.sh scale-proxy 2   # scale proxy server instances to 2
```
**Viewing logs**
```
kubectl logs -f --tail=1000 -n milvus milvus-ro-servers-0   # follow the log of compute node milvus-ro-servers-0
```
## Testing
**Run unit tests**
```
1. cd milvus/shards
2. make test
```
**Unit test coverage**
```
1. cd milvus/shards
2. make coverage
```
**Code style check**
```
1. cd milvus/shards
2. make style
```
## Mishards Configuration Reference
### Global
| Name | Required | Type | Default Value | Explanation |
| --------------------------- | -------- | -------- | ------------- | ------------- |
| Debug | No | bool | True | Whether to run in debug mode |
| TIMEZONE | No | string | "UTC" | Time zone |
| MAX_RETRY | No | int | 3 | Maximum number of connection retries |
| SERVER_PORT | No | int | 19530 | Service port |
| WOSERVER | **Yes** | str | - | Address of the writable Milvus backend instance. Currently only a static setting is supported, e.g. "tcp://127.0.0.1:19530" |
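These variables, like those in the tables that follow, are read from `mishards/.env`. A minimal `.env` sketch for a static two-host setup (all values are illustrative examples, not shipped defaults):

```
DEBUG=True
SERVER_PORT=19530
WOSERVER=tcp://127.0.0.1:19530
SQLALCHEMY_DATABASE_URI=sqlite:///tmp/mishards_meta.sqlite?check_same_thread=False
DISCOVERY_CLASS_NAME=static
DISCOVERY_STATIC_HOSTS=192.168.1.188,192.168.1.190
DISCOVERY_STATIC_PORT=19530
```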
### Metadata
| Name | Required | Type | Default Value | Explanation |
| --------------------------- | -------- | -------- | ------------- | ------------- |
| SQLALCHEMY_DATABASE_URI | **Yes** | string | - | URI of the metadata storage database |
| SQL_ECHO | No | bool | False | Whether to print detailed SQL statements |
| SQLALCHEMY_DATABASE_TEST_URI | No | string | - | URI of the metadata storage database in the test environment |
| SQL_TEST_ECHO | No | bool | False | Whether to print detailed SQL statements in the test environment |
### Service Discovery
| Name | Required | Type | Default Value | Explanation |
| --------------------------- | -------- | -------- | ------------- | ------------- |
| DISCOVERY_PLUGIN_PATH | No | string | - | Search path for user-defined service discovery plugins; defaults to the system search path |
| DISCOVERY_CLASS_NAME | No | string | static | Class looked up under the plugin search path and instantiated. The system currently provides **static** and **kubernetes**; defaults to **static** |
| DISCOVERY_STATIC_HOSTS | No | list | [] | When **DISCOVERY_CLASS_NAME** is **static**, the comma-separated list of service addresses, e.g. "192.168.1.188,192.168.1.190" |
| DISCOVERY_STATIC_PORT | No | int | 19530 | When **DISCOVERY_CLASS_NAME** is **static**, the port the hosts listen on |
| DISCOVERY_KUBERNETES_NAMESPACE | No | string | - | When **DISCOVERY_CLASS_NAME** is **kubernetes**, the cluster namespace |
| DISCOVERY_KUBERNETES_IN_CLUSTER | No | bool | False | When **DISCOVERY_CLASS_NAME** is **kubernetes**, whether service discovery itself runs inside the cluster |
| DISCOVERY_KUBERNETES_POLL_INTERVAL | No | int | 5 | When **DISCOVERY_CLASS_NAME** is **kubernetes**, the interval at which the service list is polled, in seconds |
| DISCOVERY_KUBERNETES_POD_PATT | No | string | - | When **DISCOVERY_CLASS_NAME** is **kubernetes**, the regular expression matching readable Milvus instances |
| DISCOVERY_KUBERNETES_LABEL_SELECTOR | No | string | - | When **DISCOVERY_CLASS_NAME** is **kubernetes**, the label selector matching readable Milvus instances |
### Tracing
| Name | Required | Type | Default Value | Explanation |
| --------------------------- | -------- | -------- | ------------- | ------------- |
| TRACER_PLUGIN_PATH | No | string | - | Search path for user-defined tracing plugins; defaults to the system search path |
| TRACER_CLASS_NAME | No | string | "" | Tracing backend; currently only **Jaeger** is implemented, and tracing is disabled by default |
| TRACING_SERVICE_NAME | No | string | "mishards" | When **TRACER_CLASS_NAME** is **Jaeger**, the tracing service name |
| TRACING_SAMPLER_TYPE | No | string | "const" | When **TRACER_CLASS_NAME** is **Jaeger**, the trace sampler type |
| TRACING_SAMPLER_PARAM | No | int | 1 | When **TRACER_CLASS_NAME** is **Jaeger**, the trace sampling rate |
| TRACING_LOG_PAYLOAD | No | bool | False | When **TRACER_CLASS_NAME** is **Jaeger**, whether traces collect payloads |
### Logging
| Name | Required | Type | Default Value | Explanation |
| --------------------------- | -------- | -------- | ------------- | ------------- |
| LOG_LEVEL | No | string | "DEBUG" if Debug is ON else "INFO" | Logging level |
| LOG_PATH | No | string | "/tmp/mishards" | Log file path |
| LOG_NAME | No | string | "logfile" | Log file name |
### Routing
| Name | Required | Type | Default Value | Explanation |
| --------------------------- | -------- | -------- | ------------- | ------------- |
| ROUTER_PLUGIN_PATH | No | string | - | Search path for user-defined routing plugins; defaults to the system search path |
| ROUTER_CLASS_NAME | No | string | FileBasedHashRingRouter | Name of the class that routes requests; custom classes can be registered. Currently the system only provides **FileBasedHashRingRouter** |
| ROUTER_CLASS_TEST_NAME | No | string | FileBasedHashRingRouter | Routing class name used in the test environment; custom classes can be registered |
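The implementation of **FileBasedHashRingRouter** is not part of this changeset, but hash-ring routing in general places readable instances on a consistent-hash ring so that each key maps deterministically to one instance and keys move minimally when instances join or leave. A minimal, illustrative sketch of the idea (all names here are hypothetical, not mishards API):

```python
import bisect
import hashlib

class HashRing:
    """Minimal consistent-hash ring mapping keys to node names."""
    def __init__(self, nodes, replicas=64):
        self.replicas = replicas      # virtual nodes per physical node
        self._ring = {}               # hash -> node name
        self._sorted = []             # sorted list of hashes on the ring
        for node in nodes:
            self.add_node(node)

    @staticmethod
    def _hash(key):
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def add_node(self, node):
        # Insert `replicas` virtual points for the node to smooth the distribution.
        for i in range(self.replicas):
            h = self._hash('{}:{}'.format(node, i))
            self._ring[h] = node
            bisect.insort(self._sorted, h)

    def get_node(self, key):
        # Walk clockwise to the first virtual point at or after the key's hash.
        if not self._sorted:
            return None
        idx = bisect.bisect(self._sorted, self._hash(key)) % len(self._sorted)
        return self._ring[self._sorted[idx]]

ring = HashRing(['milvus-ro-servers-0', 'milvus-ro-servers-1'])
# The same key always routes to the same server.
assert ring.get_node('table_1/segment_42') == ring.get_node('table_1/segment_42')
```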


@ -0,0 +1,53 @@
version: "2.3"
services:
milvus_wr:
runtime: nvidia
restart: always
image: milvusdb/milvus:0.5.0-d102119-ede20b
volumes:
- /tmp/milvus/db:/opt/milvus/db
milvus_ro:
runtime: nvidia
restart: always
image: milvusdb/milvus:0.5.0-d102119-ede20b
volumes:
- /tmp/milvus/db:/opt/milvus/db
- ./ro_server.yml:/opt/milvus/conf/server_config.yaml
jaeger:
restart: always
image: jaegertracing/all-in-one:1.14
ports:
- "0.0.0.0:5775:5775/udp"
- "0.0.0.0:16686:16686"
- "0.0.0.0:9441:9441"
environment:
COLLECTOR_ZIPKIN_HTTP_PORT: 9411
mishards:
restart: always
image: milvusdb/mishards
ports:
- "0.0.0.0:19531:19531"
- "0.0.0.0:19532:19532"
volumes:
- /tmp/milvus/db:/tmp/milvus/db
# - /tmp/mishards_env:/source/mishards/.env
command: ["python", "mishards/main.py"]
environment:
FROM_EXAMPLE: 'true'
DEBUG: 'true'
SERVER_PORT: 19531
WOSERVER: tcp://milvus_wr:19530
DISCOVERY_CLASS_NAME: static
DISCOVERY_STATIC_HOSTS: milvus_wr,milvus_ro
TRACER_CLASS_NAME: jaeger
TRACING_SERVICE_NAME: mishards-demo
TRACING_REPORTING_HOST: jaeger
TRACING_REPORTING_PORT: 5775
depends_on:
- milvus_wr
- milvus_ro
- jaeger


@ -0,0 +1,25 @@
from milvus import Milvus
RED = '\033[0;31m'
GREEN = '\033[0;32m'
ENDC = '\033[0m'  # ANSI reset; an empty string would leave the terminal colored
def test(host='127.0.0.1', port=19531):
client = Milvus()
try:
status = client.connect(host=host, port=port)
if status.OK():
print('{}Pass: Connected{}'.format(GREEN, ENDC))
return 0
else:
print('{}Error: {}{}'.format(RED, status, ENDC))
return 1
except Exception as exc:
print('{}Error: {}{}'.format(RED, exc, ENDC))
return 1
if __name__ == '__main__':
import fire
fire.Fire(test)


@ -0,0 +1,41 @@
server_config:
address: 0.0.0.0 # milvus server ip address (IPv4)
port: 19530 # port range: 1025 ~ 65534
deploy_mode: cluster_readonly # deployment type: single, cluster_readonly, cluster_writable
time_zone: UTC+8
db_config:
primary_path: /opt/milvus # path used to store data and meta
secondary_path: # path used to store data only, split by semicolon
backend_url: sqlite://:@:/ # URI format: dialect://username:password@host:port/database
# Keep 'dialect://:@:/', and replace other texts with real values
# Replace 'dialect' with 'mysql' or 'sqlite'
insert_buffer_size: 4 # GB, maximum insert buffer size allowed
# sum of insert_buffer_size and cpu_cache_capacity cannot exceed total memory
preload_table: # preload data at startup, '*' means load all tables, empty value means no preload
# you can specify preload tables like this: table1,table2,table3
metric_config:
enable_monitor: false # enable monitoring or not
collector: prometheus # prometheus
prometheus_config:
port: 8080 # port prometheus uses to fetch metrics
cache_config:
cpu_cache_capacity: 16 # GB, CPU memory used for cache
cpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered
gpu_cache_capacity: 4 # GB, GPU memory used for cache
gpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered
cache_insert_data: false # whether to load inserted data into cache
engine_config:
use_blas_threshold: 20 # if nq < use_blas_threshold, use SSE, faster with fluctuated response times
# if nq >= use_blas_threshold, use OpenBlas, slower with stable response times
resource_config:
search_resources: # define the GPUs used for search computation, valid value: gpux
- gpu0
index_build_device: gpu0 # GPU used for building index

39
shards/conftest.py Normal file

@ -0,0 +1,39 @@
import os
import logging
import pytest
import grpc
import tempfile
import shutil
from mishards import settings, db, create_app
logger = logging.getLogger(__name__)
tpath = tempfile.mkdtemp()
dirpath = '{}/db'.format(tpath)
filepath = '{}/meta.sqlite'.format(dirpath)
os.makedirs(dirpath, 0o777)
settings.TestingConfig.SQLALCHEMY_DATABASE_URI = 'sqlite:///{}?check_same_thread=False'.format(
filepath)
@pytest.fixture
def app(request):
app = create_app(settings.TestingConfig)
db.drop_all()
db.create_all()
yield app
db.drop_all()
app.stop()
# shutil.rmtree(tpath)
@pytest.fixture
def started_app(app):
app.on_pre_run()
app.start(settings.SERVER_TEST_PORT)
yield app
app.stop()


@ -0,0 +1,37 @@
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(
os.path.abspath(__file__))))
import logging
from utils import dotdict
logger = logging.getLogger(__name__)
class DiscoveryConfig(dotdict):
CONFIG_PREFIX = 'DISCOVERY_'
def dump(self):
logger.info('----------- DiscoveryConfig -----------------')
for k, v in self.items():
logger.info('{}: {}'.format(k, v))
if len(self) <= 0:
logger.error(' Empty DiscoveryConfig Found! ')
logger.info('---------------------------------------------')
@classmethod
def Create(cls, **kwargs):
o = cls()
for k, v in os.environ.items():
if not k.startswith(cls.CONFIG_PREFIX):
continue
o[k] = v
for k, v in kwargs.items():
o[k] = v
o.dump()
return o
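`DiscoveryConfig.Create()` above collects every environment variable whose name starts with `DISCOVERY_` and lets keyword arguments override them. The same prefix-filtering idea as a standalone function (function name is illustrative):

```python
import os

def collect_prefixed(prefix, environ=None, **overrides):
    """Collect entries starting with `prefix`; keyword overrides win over env."""
    environ = os.environ if environ is None else environ
    config = {k: v for k, v in environ.items() if k.startswith(prefix)}
    config.update(overrides)
    return config

env = {'DISCOVERY_STATIC_PORT': '19530', 'PATH': '/usr/bin'}
conf = collect_prefixed('DISCOVERY_', environ=env, DISCOVERY_CLASS_NAME='static')
# conf == {'DISCOVERY_STATIC_PORT': '19530', 'DISCOVERY_CLASS_NAME': 'static'}
```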


@ -0,0 +1,22 @@
import logging
from discovery import DiscoveryConfig
from utils.plugins import BaseMixin
logger = logging.getLogger(__name__)
PLUGIN_PACKAGE_NAME = 'discovery.plugins'
class DiscoveryFactory(BaseMixin):
PLUGIN_TYPE = 'Discovery'
def __init__(self, searchpath=None):
super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME)
def _create(self, plugin_class, **kwargs):
conn_mgr = kwargs.pop('conn_mgr', None)
if not conn_mgr:
raise RuntimeError('Please pass conn_mgr to create discovery!')
plugin_config = DiscoveryConfig.Create()
plugin = plugin_class.Create(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs)
return plugin
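`DiscoveryFactory` resolves a plugin class by name and delegates construction to the class's `Create` classmethod. The pattern in miniature, with a hypothetical registry standing in for the real plugin loader:

```python
class StaticStub:
    """Stand-in for a discovery plugin class exposing the Create() protocol."""
    @classmethod
    def Create(cls, conn_mgr, **kwargs):
        plugin = cls()
        plugin.conn_mgr = conn_mgr
        return plugin

# name -> class registry; the real factory discovers these from a plugin package
REGISTRY = {'static': StaticStub}

def create_discovery(class_name, conn_mgr):
    if conn_mgr is None:
        raise RuntimeError('Please pass conn_mgr to create discovery!')
    try:
        plugin_class = REGISTRY[class_name]
    except KeyError:
        raise RuntimeError('Unknown discovery plugin: {}'.format(class_name))
    return plugin_class.Create(conn_mgr=conn_mgr)
```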


@ -0,0 +1,346 @@
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(
os.path.abspath(__file__))))
import re
import logging
import time
import copy
import threading
import queue
import enum
from kubernetes import client, config, watch
logger = logging.getLogger(__name__)
INCLUSTER_NAMESPACE_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
class EventType(enum.Enum):
PodHeartBeat = 1
Watch = 2
class K8SMixin:
def __init__(self, namespace, in_cluster=False, **kwargs):
self.namespace = namespace
self.in_cluster = in_cluster
self.kwargs = kwargs
self.v1 = kwargs.get('v1', None)
if not self.namespace:
self.namespace = open(INCLUSTER_NAMESPACE_PATH).read()
if not self.v1:
config.load_incluster_config(
) if self.in_cluster else config.load_kube_config()
self.v1 = client.CoreV1Api()
class K8SHeartbeatHandler(threading.Thread, K8SMixin):
name = 'kubernetes'
def __init__(self,
message_queue,
namespace,
label_selector,
in_cluster=False,
**kwargs):
K8SMixin.__init__(self,
namespace=namespace,
in_cluster=in_cluster,
**kwargs)
threading.Thread.__init__(self)
self.queue = message_queue
self.terminate = False
self.label_selector = label_selector
self.poll_interval = kwargs.get('poll_interval', 5)
def run(self):
while not self.terminate:
try:
pods = self.v1.list_namespaced_pod(
namespace=self.namespace,
label_selector=self.label_selector)
event_message = {'eType': EventType.PodHeartBeat, 'events': []}
for item in pods.items:
pod = self.v1.read_namespaced_pod(name=item.metadata.name,
namespace=self.namespace)
name = pod.metadata.name
ip = pod.status.pod_ip
phase = pod.status.phase
reason = pod.status.reason
message = pod.status.message
ready = phase == 'Running'
pod_event = dict(pod=name,
ip=ip,
ready=ready,
reason=reason,
message=message)
event_message['events'].append(pod_event)
self.queue.put(event_message)
except Exception as exc:
logger.error(exc)
time.sleep(self.poll_interval)
def stop(self):
self.terminate = True
class K8SEventListener(threading.Thread, K8SMixin):
def __init__(self, message_queue, namespace, in_cluster=False, **kwargs):
K8SMixin.__init__(self,
namespace=namespace,
in_cluster=in_cluster,
**kwargs)
threading.Thread.__init__(self)
self.queue = message_queue
self.terminate = False
self.at_start_up = True
self._stop_event = threading.Event()
def stop(self):
self.terminate = True
self._stop_event.set()
def run(self):
resource_version = ''
w = watch.Watch()
for event in w.stream(self.v1.list_namespaced_event,
namespace=self.namespace,
field_selector='involvedObject.kind=Pod'):
if self.terminate:
break
resource_version = int(event['object'].metadata.resource_version)
info = dict(
eType=EventType.Watch,
pod=event['object'].involved_object.name,
reason=event['object'].reason,
message=event['object'].message,
start_up=self.at_start_up,
)
self.at_start_up = False
# logger.info('Received event: {}'.format(info))
self.queue.put(info)
class EventHandler(threading.Thread):
def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs):
threading.Thread.__init__(self)
self.mgr = mgr
self.queue = message_queue
self.kwargs = kwargs
self.terminate = False
self.pod_patt = re.compile(pod_patt)
self.namespace = namespace
def stop(self):
self.terminate = True
def on_drop(self, event, **kwargs):
pass
def on_pod_started(self, event, **kwargs):
try_cnt = 3
pod = None
while try_cnt > 0:
try_cnt -= 1
try:
pod = self.mgr.v1.read_namespaced_pod(name=event['pod'],
namespace=self.namespace)
if not pod.status.pod_ip:
time.sleep(0.5)
continue
break
except client.rest.ApiException as exc:
time.sleep(0.5)
if try_cnt <= 0 and not pod:
if not event['start_up']:
logger.error('Pod {} is started but cannot read pod'.format(
event['pod']))
return
elif try_cnt <= 0 and not pod.status.pod_ip:
logger.warning('NoPodIPFoundError')
return
logger.info('Register POD {} with IP {}'.format(
pod.metadata.name, pod.status.pod_ip))
self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip)
def on_pod_killing(self, event, **kwargs):
logger.info('Unregister POD {}'.format(event['pod']))
self.mgr.delete_pod(name=event['pod'])
def on_pod_heartbeat(self, event, **kwargs):
names = self.mgr.conn_mgr.conn_names
running_names = set()
for each_event in event['events']:
if each_event['ready']:
self.mgr.add_pod(name=each_event['pod'], ip=each_event['ip'])
running_names.add(each_event['pod'])
else:
self.mgr.delete_pod(name=each_event['pod'])
to_delete = names - running_names
for name in to_delete:
self.mgr.delete_pod(name)
logger.info(self.mgr.conn_mgr.conn_names)
def handle_event(self, event):
if event['eType'] == EventType.PodHeartBeat:
return self.on_pod_heartbeat(event)
if not event or (event['reason'] not in ('Started', 'Killing')):
return self.on_drop(event)
if not re.match(self.pod_patt, event['pod']):
return self.on_drop(event)
logger.info('Handling event: {}'.format(event))
if event['reason'] == 'Started':
return self.on_pod_started(event)
return self.on_pod_killing(event)
def run(self):
while not self.terminate:
try:
event = self.queue.get(timeout=1)
self.handle_event(event)
except queue.Empty:
continue
class KubernetesProviderSettings:
def __init__(self, namespace, pod_patt, label_selector, in_cluster,
poll_interval, port=None, **kwargs):
self.namespace = namespace
self.pod_patt = pod_patt
self.label_selector = label_selector
self.in_cluster = in_cluster
self.poll_interval = poll_interval
self.port = int(port) if port else 19530
class KubernetesProvider(object):
name = 'kubernetes'
def __init__(self, plugin_config, conn_mgr, **kwargs):
self.namespace = plugin_config.DISCOVERY_KUBERNETES_NAMESPACE
self.pod_patt = plugin_config.DISCOVERY_KUBERNETES_POD_PATT
self.label_selector = plugin_config.DISCOVERY_KUBERNETES_LABEL_SELECTOR
self.in_cluster = plugin_config.DISCOVERY_KUBERNETES_IN_CLUSTER.lower()
self.in_cluster = self.in_cluster == 'true'
self.poll_interval = plugin_config.DISCOVERY_KUBERNETES_POLL_INTERVAL
self.poll_interval = int(self.poll_interval) if self.poll_interval else 5
self.port = plugin_config.DISCOVERY_KUBERNETES_PORT
self.port = int(self.port) if self.port else 19530
self.kwargs = kwargs
self.queue = queue.Queue()
self.conn_mgr = conn_mgr
if not self.namespace:
self.namespace = open(INCLUSTER_NAMESPACE_PATH).read()
config.load_incluster_config(
) if self.in_cluster else config.load_kube_config()
self.v1 = client.CoreV1Api()
self.listener = K8SEventListener(message_queue=self.queue,
namespace=self.namespace,
in_cluster=self.in_cluster,
v1=self.v1,
**kwargs)
self.pod_heartbeater = K8SHeartbeatHandler(
message_queue=self.queue,
namespace=self.namespace,
label_selector=self.label_selector,
in_cluster=self.in_cluster,
v1=self.v1,
poll_interval=self.poll_interval,
**kwargs)
self.event_handler = EventHandler(mgr=self,
message_queue=self.queue,
namespace=self.namespace,
pod_patt=self.pod_patt,
**kwargs)
def add_pod(self, name, ip):
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
def delete_pod(self, name):
self.conn_mgr.unregister(name)
def start(self):
self.listener.daemon = True
self.listener.start()
self.event_handler.start()
self.pod_heartbeater.start()
def stop(self):
self.listener.stop()
self.pod_heartbeater.stop()
self.event_handler.stop()
@classmethod
def Create(cls, conn_mgr, plugin_config, **kwargs):
discovery = cls(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs)
return discovery
def setup(app):
logger.info('Plugin \'{}\' Installed In Package: {}'.format(__file__, app.plugin_package_name))
app.on_plugin_setup(KubernetesProvider)
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
os.path.abspath(__file__))))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(
os.path.abspath(__file__)))))
class Connect:
def register(self, name, value):
logger.error('Register: {} - {}'.format(name, value))
def unregister(self, name):
logger.error('Unregister: {}'.format(name))
@property
def conn_names(self):
return set()
connect_mgr = Connect()
from discovery import DiscoveryConfig
settings = DiscoveryConfig(DISCOVERY_KUBERNETES_NAMESPACE='xp',
DISCOVERY_KUBERNETES_POD_PATT=".*-ro-servers-.*",
DISCOVERY_KUBERNETES_LABEL_SELECTOR='tier=ro-servers',
DISCOVERY_KUBERNETES_POLL_INTERVAL=5,
DISCOVERY_KUBERNETES_IN_CLUSTER=False)
provider_class = KubernetesProvider
t = provider_class(conn_mgr=connect_mgr, plugin_config=settings)
t.start()
cnt = 100
while cnt > 0:
time.sleep(2)
cnt -= 1
t.stop()


@ -0,0 +1,45 @@
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import logging
import socket
from environs import Env
logger = logging.getLogger(__name__)
env = Env()
class StaticDiscovery(object):
name = 'static'
def __init__(self, config, conn_mgr, **kwargs):
self.conn_mgr = conn_mgr
hosts = env.list('DISCOVERY_STATIC_HOSTS', [])
self.port = env.int('DISCOVERY_STATIC_PORT', 19530)
self.hosts = [socket.gethostbyname(host) for host in hosts]
def start(self):
for host in self.hosts:
self.add_pod(host, host)
def stop(self):
for host in self.hosts:
self.delete_pod(host)
def add_pod(self, name, ip):
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
def delete_pod(self, name):
self.conn_mgr.unregister(name)
@classmethod
def Create(cls, conn_mgr, plugin_config, **kwargs):
discovery = cls(config=plugin_config, conn_mgr=conn_mgr, **kwargs)
return discovery
def setup(app):
logger.info('Plugin \'{}\' Installed In Package: {}'.format(__file__, app.plugin_package_name))
app.on_plugin_setup(StaticDiscovery)
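`StaticDiscovery` resolves each configured host name and registers one `tcp://<ip>:<port>` address per host with the connection manager. The address-building step in isolation (helper name is illustrative):

```python
import socket

def build_addresses(hosts, port=19530):
    """Resolve each host and build the tcp:// address mishards would register."""
    return {host: 'tcp://{}:{}'.format(socket.gethostbyname(host), port)
            for host in hosts}

addrs = build_addresses(['127.0.0.1'], port=19530)
# addrs == {'127.0.0.1': 'tcp://127.0.0.1:19530'}
```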


@ -0,0 +1,67 @@
kind: Service
apiVersion: v1
metadata:
name: milvus-mysql
namespace: milvus
spec:
type: ClusterIP
selector:
app: milvus
tier: mysql
ports:
- protocol: TCP
port: 3306
targetPort: 3306
name: mysql
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: milvus-mysql
namespace: milvus
spec:
selector:
matchLabels:
app: milvus
tier: mysql
replicas: 1
template:
metadata:
labels:
app: milvus
tier: mysql
spec:
containers:
- name: milvus-mysql
image: mysql:5.7
imagePullPolicy: IfNotPresent
# lifecycle:
# postStart:
# exec:
# command: ["/bin/sh", "-c", "mysql -h milvus-mysql -uroot -p${MYSQL_ROOT_PASSWORD} -e \"CREATE DATABASE IF NOT EXISTS ${DATABASE};\"; \
# mysql -uroot -p${MYSQL_ROOT_PASSWORD} -e \"GRANT ALL PRIVILEGES ON ${DATABASE}.* TO 'root'@'%';\""]
env:
- name: MYSQL_ROOT_PASSWORD
value: milvusroot
- name: DATABASE
value: milvus
ports:
- name: mysql-port
containerPort: 3306
volumeMounts:
- name: milvus-mysql-disk
mountPath: /data
subPath: mysql
- name: milvus-mysql-configmap
mountPath: /etc/mysql/mysql.conf.d/mysqld.cnf
subPath: milvus_mysql_config.yml
volumes:
- name: milvus-mysql-disk
persistentVolumeClaim:
claimName: milvus-mysql-disk
- name: milvus-mysql-configmap
configMap:
name: milvus-mysql-configmap


@ -0,0 +1,185 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: milvus-mysql-configmap
namespace: milvus
data:
milvus_mysql_config.yml: |
[mysqld]
pid-file = /var/run/mysqld/mysqld.pid
socket = /var/run/mysqld/mysqld.sock
datadir = /data
log-error = /var/log/mysql/error.log # mount out to host
# By default we only accept connections from localhost
bind-address = 0.0.0.0
# Disabling symbolic-links is recommended to prevent assorted security risks
symbolic-links=0
character-set-server = utf8mb4
collation-server = utf8mb4_unicode_ci
init_connect='SET NAMES utf8mb4'
skip-character-set-client-handshake = true
max_connections = 1000
wait_timeout = 31536000
---
apiVersion: v1
kind: ConfigMap
metadata:
name: milvus-proxy-configmap
namespace: milvus
data:
milvus_proxy_config.yml: |
DEBUG=True
TESTING=False
WOSERVER=tcp://milvus-wo-servers:19530
SERVER_PORT=19530
DISCOVERY_CLASS_NAME=kubernetes
DISCOVERY_KUBERNETES_NAMESPACE=milvus
DISCOVERY_KUBERNETES_POD_PATT=.*-ro-servers-.*
DISCOVERY_KUBERNETES_LABEL_SELECTOR=tier=ro-servers
DISCOVERY_KUBERNETES_POLL_INTERVAL=10
DISCOVERY_KUBERNETES_IN_CLUSTER=True
SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:milvusroot@milvus-mysql:3306/milvus?charset=utf8mb4
SQLALCHEMY_POOL_SIZE=50
SQLALCHEMY_POOL_RECYCLE=7200
LOG_PATH=/var/log/milvus
TIMEZONE=Asia/Shanghai
---
apiVersion: v1
kind: ConfigMap
metadata:
name: milvus-roserver-configmap
namespace: milvus
data:
config.yml: |
server_config:
address: 0.0.0.0
port: 19530
mode: cluster_readonly
db_config:
primary_path: /var/milvus
backend_url: mysql://root:milvusroot@milvus-mysql:3306/milvus
insert_buffer_size: 2
metric_config:
enable_monitor: off # true is on, false is off
cache_config:
cpu_cache_capacity: 12 # memory pool to hold index data, unit: GB
cpu_cache_free_percent: 0.85
insert_cache_immediately: false
# gpu_cache_capacity: 4
# gpu_cache_free_percent: 0.85
# gpu_ids:
# - 0
engine_config:
use_blas_threshold: 800
resource_config:
search_resources:
- gpu0
log.conf: |
* GLOBAL:
FORMAT = "%datetime | %level | %logger | %msg"
FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-global.log"
ENABLED = true
TO_FILE = true
TO_STANDARD_OUTPUT = true
SUBSECOND_PRECISION = 3
PERFORMANCE_TRACKING = false
MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB
* DEBUG:
FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-debug.log"
ENABLED = true
* WARNING:
FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-warning.log"
* TRACE:
FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-trace.log"
* VERBOSE:
FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
TO_FILE = true
TO_STANDARD_OUTPUT = true
## Error logs
* ERROR:
ENABLED = true
FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-error.log"
* FATAL:
ENABLED = true
FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-fatal.log"
---
apiVersion: v1
kind: ConfigMap
metadata:
name: milvus-woserver-configmap
namespace: milvus
data:
config.yml: |
server_config:
address: 0.0.0.0
port: 19530
mode: cluster_writable
db_config:
primary_path: /var/milvus
backend_url: mysql://root:milvusroot@milvus-mysql:3306/milvus
insert_buffer_size: 2
metric_config:
enable_monitor: off # true is on, false is off
cache_config:
cpu_cache_capacity: 2 # memory pool to hold index data, unit: GB
cpu_cache_free_percent: 0.85
insert_cache_immediately: false
# gpu_cache_capacity: 4
# gpu_cache_free_percent: 0.85
# gpu_ids:
# - 0
engine_config:
use_blas_threshold: 800
resource_config:
search_resources:
- gpu0
log.conf: |
* GLOBAL:
FORMAT = "%datetime | %level | %logger | %msg"
FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-global.log"
ENABLED = true
TO_FILE = true
TO_STANDARD_OUTPUT = true
SUBSECOND_PRECISION = 3
PERFORMANCE_TRACKING = false
MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB
* DEBUG:
FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-debug.log"
ENABLED = true
* WARNING:
FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-warning.log"
* TRACE:
FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-trace.log"
* VERBOSE:
FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
TO_FILE = true
TO_STANDARD_OUTPUT = true
## Error logs
* ERROR:
ENABLED = true
FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-error.log"
* FATAL:
ENABLED = true
FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-fatal.log"


@ -0,0 +1,57 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: milvus-db-disk
namespace: milvus
spec:
accessModes:
- ReadWriteMany
storageClassName: default
resources:
requests:
storage: 50Gi
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: milvus-log-disk
namespace: milvus
spec:
accessModes:
- ReadWriteMany
storageClassName: default
resources:
requests:
storage: 50Gi
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: milvus-mysql-disk
namespace: milvus
spec:
accessModes:
- ReadWriteMany
storageClassName: default
resources:
requests:
storage: 50Gi
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: milvus-redis-disk
namespace: milvus
spec:
accessModes:
- ReadWriteOnce
storageClassName: default
resources:
requests:
storage: 5Gi


@ -0,0 +1,88 @@
kind: Service
apiVersion: v1
metadata:
name: milvus-proxy-servers
namespace: milvus
spec:
type: LoadBalancer
selector:
app: milvus
tier: proxy
ports:
- name: tcp
protocol: TCP
port: 19530
targetPort: 19530
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: milvus-proxy
namespace: milvus
spec:
selector:
matchLabels:
app: milvus
tier: proxy
replicas: 1
template:
metadata:
labels:
app: milvus
tier: proxy
spec:
containers:
- name: milvus-proxy
image: milvusdb/mishards:0.1.0-rc0
imagePullPolicy: Always
command: ["python", "mishards/main.py"]
resources:
limits:
memory: "3Gi"
cpu: "4"
requests:
memory: "2Gi"
ports:
- name: tcp
containerPort: 5000
env:
# - name: SQL_ECHO
# value: "True"
- name: DEBUG
value: "False"
- name: POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: MILVUS_CLIENT
value: "False"
- name: LOG_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: LOG_PATH
value: /var/log/milvus
- name: SD_NAMESPACE
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: SD_ROSERVER_POD_PATT
value: ".*-ro-servers-.*"
volumeMounts:
- name: milvus-proxy-configmap
mountPath: /source/mishards/.env
subPath: milvus_proxy_config.yml
- name: milvus-log-disk
mountPath: /var/log/milvus
subPath: proxylog
# imagePullSecrets:
# - name: regcred
volumes:
- name: milvus-proxy-configmap
configMap:
name: milvus-proxy-configmap
- name: milvus-log-disk
persistentVolumeClaim:
claimName: milvus-log-disk


@ -0,0 +1,24 @@
kind: ClusterRole
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: pods-list
rules:
- apiGroups: [""]
resources: ["pods", "events"]
verbs: ["list", "get", "watch"]
---
kind: ClusterRoleBinding
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: pods-list
subjects:
- kind: ServiceAccount
name: default
namespace: milvus
roleRef:
kind: ClusterRole
name: pods-list
apiGroup: rbac.authorization.k8s.io
---
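The proxy discovers read-only server pods by listing pods through the Kubernetes API (see the SD_ROSERVER_POD_PATT environment variable above), which is what this ClusterRole grants to the default service account. A quick sanity check after applying the manifest (a sketch; output depends on the cluster):

```shell
# Verify the default service account in the milvus namespace can list
# pods, as required by the proxy's pod-discovery loop.
kubectl auth can-i list pods \
  --as=system:serviceaccount:milvus:default
# Once the binding is active this should print "yes".
```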


@ -0,0 +1,68 @@
kind: Service
apiVersion: v1
metadata:
name: milvus-ro-servers
namespace: milvus
spec:
type: ClusterIP
selector:
app: milvus
tier: ro-servers
ports:
- protocol: TCP
port: 19530
targetPort: 19530
---
apiVersion: apps/v1beta1
kind: StatefulSet
metadata:
name: milvus-ro-servers
namespace: milvus
spec:
serviceName: "milvus-ro-servers"
replicas: 1
template:
metadata:
labels:
app: milvus
tier: ro-servers
spec:
terminationGracePeriodSeconds: 11
containers:
- name: milvus-ro-server
image: milvusdb/milvus:0.5.0-d102119-ede20b
imagePullPolicy: Always
ports:
- containerPort: 19530
resources:
limits:
memory: "16Gi"
cpu: "8.0"
requests:
memory: "14Gi"
volumeMounts:
- name: milvus-db-disk
mountPath: /var/milvus
subPath: dbdata
- name: milvus-roserver-configmap
mountPath: /opt/milvus/conf/server_config.yaml
subPath: config.yml
- name: milvus-roserver-configmap
mountPath: /opt/milvus/conf/log_config.conf
subPath: log.conf
# imagePullSecrets:
# - name: regcred
# tolerations:
# - key: "worker"
# operator: "Equal"
# value: "performance"
# effect: "NoSchedule"
volumes:
- name: milvus-roserver-configmap
configMap:
name: milvus-roserver-configmap
- name: milvus-db-disk
persistentVolumeClaim:
claimName: milvus-db-disk


@ -0,0 +1,70 @@
kind: Service
apiVersion: v1
metadata:
name: milvus-wo-servers
namespace: milvus
spec:
type: ClusterIP
selector:
app: milvus
tier: wo-servers
ports:
- protocol: TCP
port: 19530
targetPort: 19530
---
apiVersion: apps/v1beta1
kind: Deployment
metadata:
name: milvus-wo-servers
namespace: milvus
spec:
selector:
matchLabels:
app: milvus
tier: wo-servers
replicas: 1
template:
metadata:
labels:
app: milvus
tier: wo-servers
spec:
containers:
- name: milvus-wo-server
image: milvusdb/milvus:0.5.0-d102119-ede20b
imagePullPolicy: Always
ports:
- containerPort: 19530
resources:
limits:
memory: "5Gi"
cpu: "1.0"
requests:
memory: "4Gi"
volumeMounts:
- name: milvus-db-disk
mountPath: /var/milvus
subPath: dbdata
- name: milvus-woserver-configmap
mountPath: /opt/milvus/conf/server_config.yaml
subPath: config.yml
- name: milvus-woserver-configmap
mountPath: /opt/milvus/conf/log_config.conf
subPath: log.conf
# imagePullSecrets:
# - name: regcred
# tolerations:
# - key: "worker"
# operator: "Equal"
# value: "performance"
# effect: "NoSchedule"
volumes:
- name: milvus-woserver-configmap
configMap:
name: milvus-woserver-configmap
- name: milvus-db-disk
persistentVolumeClaim:
claimName: milvus-db-disk

shards/kubernetes_demo/start.sh Executable file

@ -0,0 +1,368 @@
#!/bin/bash
UL=`tput smul`
NOUL=`tput rmul`
BOLD=`tput bold`
NORMAL=`tput sgr0`
RED='\033[0;31m'
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
ENDC='\033[0m'
function showHelpMessage () {
echo -e "${BOLD}Usage:${NORMAL} ${RED}$0${ENDC} [option...] {cleanup${GREEN}|${ENDC}baseup${GREEN}|${ENDC}appup${GREEN}|${ENDC}appdown${GREEN}|${ENDC}allup}" >&2
echo
echo " -h, --help show help message"
echo " ${BOLD}cleanup, delete all resources${NORMAL}"
echo " ${BOLD}baseup, start all required base resources${NORMAL}"
echo " ${BOLD}appup, start all pods${NORMAL}"
echo " ${BOLD}appdown, remove all pods${NORMAL}"
echo " ${BOLD}allup, start all base resources and pods${NORMAL}"
echo " ${BOLD}scale-proxy, scale proxy${NORMAL}"
echo " ${BOLD}scale-ro-server, scale readonly servers${NORMAL}"
echo " ${BOLD}scale-worker, scale calculation workers${NORMAL}"
}
function showscaleHelpMessage () {
echo -e "${BOLD}Usage:${NORMAL} ${RED}$0 $1${ENDC} [option...] {1|2|3|4|...}" >&2
echo
echo " -h, --help show help message"
echo "   ${BOLD}number, (int) target scale number${NORMAL}"
}
function PrintScaleSuccessMessage() {
echo -e "${BLUE}${BOLD}Successfully Scaled: ${1} --> ${2}${ENDC}"
}
function PrintPodStatusMessage() {
echo -e "${BOLD}${1}${NORMAL}"
}
timeout=60
function setUpMysql () {
mysqlUserName=$(kubectl describe configmap -n milvus milvus-roserver-configmap |
grep backend_url |
awk '{print $2}' |
awk '{split($0, level1, ":");
split(level1[2], level2, "/");
print level2[3]}')
mysqlPassword=$(kubectl describe configmap -n milvus milvus-roserver-configmap |
grep backend_url |
awk '{print $2}' |
awk '{split($0, level1, ":");
split(level1[3], level3, "@");
print level3[1]}')
mysqlDBName=$(kubectl describe configmap -n milvus milvus-roserver-configmap |
grep backend_url |
awk '{print $2}' |
awk '{split($0, level1, ":");
split(level1[4], level4, "/");
print level4[2]}')
mysqlContainer=$(kubectl get pods -n milvus | grep milvus-mysql | awk '{print $1}')
kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "CREATE DATABASE IF NOT EXISTS $mysqlDBName;"
checkDBExists=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '$mysqlDBName';" | grep -o $mysqlDBName | wc -l)
counter=0
while [ $checkDBExists -lt 1 ]; do
sleep 1
let counter=counter+1
if [ $counter == $timeout ]; then
echo "Creating MySQL database $mysqlDBName timeout"
return 1
fi
checkDBExists=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '$mysqlDBName';" | grep -o $mysqlDBName | wc -l)
done;
kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "GRANT ALL PRIVILEGES ON $mysqlDBName.* TO '$mysqlUserName'@'%';"
kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "FLUSH PRIVILEGES;"
checkGrant=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SHOW GRANTS for $mysqlUserName;" | grep -o "GRANT ALL PRIVILEGES ON \`$mysqlDBName\`\.\*" | wc -l)
counter=0
while [ $checkGrant -lt 1 ]; do
sleep 1
let counter=counter+1
if [ $counter == $timeout ]; then
echo "Granting all privileges on $mysqlDBName to $mysqlUserName timeout"
return 1
fi
checkGrant=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SHOW GRANTS for $mysqlUserName;" | grep -o "GRANT ALL PRIVILEGES ON \`$mysqlDBName\`\.\*" | wc -l)
done;
}
function checkStatefulSevers() {
stateful_replicas=$(kubectl describe statefulset -n milvus milvus-ro-servers | grep "Replicas:" | awk '{print $2}')
stateful_running_pods=$(kubectl describe statefulset -n milvus milvus-ro-servers | grep "Pods Status:" | awk '{print $3}')
counter=0
prev=$stateful_running_pods
PrintPodStatusMessage "Running milvus-ro-servers Pods: $stateful_running_pods/$stateful_replicas"
while [ $stateful_replicas != $stateful_running_pods ]; do
echo -e "${YELLOW}Wait another 1 sec --- ${counter}${ENDC}"
sleep 1;
let counter=counter+1
if [ $counter -eq $timeout ]; then
return 1;
fi
stateful_running_pods=$(kubectl describe statefulset -n milvus milvus-ro-servers | grep "Pods Status:" | awk '{print $3}')
if [ $stateful_running_pods -ne $prev ]; then
PrintPodStatusMessage "Running milvus-ro-servers Pods: $stateful_running_pods/$stateful_replicas"
fi
prev=$stateful_running_pods
done;
return 0;
}
function checkDeployment() {
deployment_name=$1
replicas=$(kubectl describe deployment -n milvus $deployment_name | grep "Replicas:" | awk '{print $2}')
running=$(kubectl get pods -n milvus | grep $deployment_name | grep Running | wc -l)
counter=0
prev=$running
PrintPodStatusMessage "Running $deployment_name Pods: $running/$replicas"
while [ $replicas != $running ]; do
echo -e "${YELLOW}Wait another 1 sec --- ${counter}${ENDC}"
sleep 1;
let counter=counter+1
if [ $counter == $timeout ]; then
return 1
fi
running=$(kubectl get pods -n milvus | grep "$deployment_name" | grep Running | wc -l)
if [ $running -ne $prev ]; then
PrintPodStatusMessage "Running $deployment_name Pods: $running/$replicas"
fi
prev=$running
done
}
function startDependencies() {
kubectl apply -f milvus_data_pvc.yaml
kubectl apply -f milvus_configmap.yaml
kubectl apply -f milvus_auxiliary.yaml
counter=0
while [ $(kubectl get pvc -n milvus | grep Bound | wc -l) != 4 ]; do
sleep 1;
let counter=counter+1
if [ $counter == $timeout ]; then
echo "baseup timeout"
return 1
fi
done
checkDeployment "milvus-mysql"
}
function startApps() {
counter=0
errmsg=""
echo -e "${GREEN}${BOLD}Checking required resources...${NORMAL}${ENDC}"
while [ $counter -lt $timeout ]; do
sleep 1;
if [ $(kubectl get pvc -n milvus 2>/dev/null | grep Bound | wc -l) != 4 ]; then
echo -e "${YELLOW}No pvc. Wait another sec... $counter${ENDC}";
errmsg='No pvc';
let counter=counter+1;
continue
fi
if [ $(kubectl get configmap -n milvus 2>/dev/null | grep milvus | wc -l) != 4 ]; then
echo -e "${YELLOW}No configmap. Wait another sec... $counter${ENDC}";
errmsg='No configmap';
let counter=counter+1;
continue
fi
if [ "$(kubectl get ep -n milvus 2>/dev/null | grep milvus-mysql | awk '{print $2}')" == "<none>" ]; then
echo -e "${YELLOW}No mysql. Wait another sec... $counter${ENDC}";
errmsg='No mysql';
let counter=counter+1;
continue
fi
# if [ $(kubectl get ep -n milvus 2>/dev/null | grep milvus-redis | awk '{print $2}') == "<none>" ]; then
# echo -e "${NORMAL}${YELLOW}No redis. Wait another sec... $counter${ENDC}";
# errmsg='No redis';
# let counter=counter+1;
# continue
# fi
break;
done
if [ $counter -ge $timeout ]; then
echo -e "${RED}${BOLD}Start APP Error: $errmsg${NORMAL}${ENDC}"
exit 1;
fi
echo -e "${GREEN}${BOLD}Setup required database ...${NORMAL}${ENDC}"
setUpMysql
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Setup MySQL database timeout${NORMAL}${ENDC}"
exit 1
fi
echo -e "${GREEN}${BOLD}Start servers ...${NORMAL}${ENDC}"
kubectl apply -f milvus_stateful_servers.yaml
kubectl apply -f milvus_write_servers.yaml
checkStatefulSevers
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Starting milvus-ro-servers timeout${NORMAL}${ENDC}"
exit 1
fi
checkDeployment "milvus-wo-servers"
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Starting milvus-wo-servers timeout${NORMAL}${ENDC}"
exit 1
fi
echo -e "${GREEN}${BOLD}Start rolebinding ...${NORMAL}${ENDC}"
kubectl apply -f milvus_rbac.yaml
echo -e "${GREEN}${BOLD}Start proxies ...${NORMAL}${ENDC}"
kubectl apply -f milvus_proxy.yaml
checkDeployment "milvus-proxy"
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Starting milvus-proxy timeout${NORMAL}${ENDC}"
exit 1
fi
# echo -e "${GREEN}${BOLD}Start flower ...${NORMAL}${ENDC}"
# kubectl apply -f milvus_flower.yaml
# checkDeployment "milvus-flower"
# if [ $? -ne 0 ]; then
# echo -e "${RED}${BOLD}Starting milvus-flower timeout${NORMAL}${ENDC}"
# exit 1
# fi
}
function removeApps () {
# kubectl delete -f milvus_flower.yaml 2>/dev/null
kubectl delete -f milvus_proxy.yaml 2>/dev/null
kubectl delete -f milvus_stateful_servers.yaml 2>/dev/null
kubectl delete -f milvus_write_servers.yaml 2>/dev/null
kubectl delete -f milvus_rbac.yaml 2>/dev/null
# kubectl delete -f milvus_monitor.yaml 2>/dev/null
}
function scaleDeployment() {
deployment_name=$1
subcommand=$2
des=$3
case $des in
-h|--help|"")
showscaleHelpMessage $subcommand
exit 3
;;
esac
cur=$(kubectl get deployment -n milvus $deployment_name |grep $deployment_name |awk '{split($2, status, "/"); print status[2];}')
echo -e "${GREEN}Current Running ${BOLD}$cur ${GREEN}${deployment_name}, Scaling to ${BOLD}$des ...${ENDC}";
scalecmd="kubectl scale deployment -n milvus ${deployment_name} --replicas=${des}"
${scalecmd}
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Scale Error: ${GREEN}${scalecmd}${ENDC}"
exit 1
fi
checkDeployment $deployment_name
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Scale ${deployment_name} timeout${NORMAL}${ENDC}"
scalecmd="kubectl scale deployment -n milvus ${deployment_name} --replicas=${cur}"
${scalecmd}
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Scale Rollback Error: ${GREEN}${scalecmd}${ENDC}"
exit 2
fi
echo -e "${BLUE}${BOLD}Scale Rollback to ${cur}${ENDC}"
exit 1
fi
PrintScaleSuccessMessage $cur $des
}
function scaleROServers() {
subcommand=$1
des=$2
case $des in
-h|--help|"")
showscaleHelpMessage $subcommand
exit 3
;;
esac
cur=$(kubectl get statefulset -n milvus milvus-ro-servers |tail -n 1 |awk '{split($2, status, "/"); print status[2];}')
echo -e "${GREEN}Current Running ${BOLD}$cur ${GREEN}Readonly Servers, Scaling to ${BOLD}$des ...${ENDC}";
scalecmd="kubectl scale sts milvus-ro-servers -n milvus --replicas=${des}"
${scalecmd}
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Scale Error: ${GREEN}${scalecmd}${ENDC}"
exit 1
fi
checkStatefulSevers
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Scale milvus-ro-servers timeout${NORMAL}${ENDC}"
scalecmd="kubectl scale sts milvus-ro-servers -n milvus --replicas=${cur}"
${scalecmd}
if [ $? -ne 0 ]; then
echo -e "${RED}${BOLD}Scale Rollback Error: ${GREEN}${scalecmd}${ENDC}"
exit 2
fi
echo -e "${BLUE}${BOLD}Scale Rollback to ${cur}${ENDC}"
exit 1
fi
PrintScaleSuccessMessage $cur $des
}
case "$1" in
cleanup)
kubectl delete -f . 2>/dev/null
echo -e "${BLUE}${BOLD}All resources are removed${NORMAL}${ENDC}"
;;
appdown)
removeApps;
echo -e "${BLUE}${BOLD}All pods are removed${NORMAL}${ENDC}"
;;
baseup)
startDependencies;
echo -e "${BLUE}${BOLD}All pvc, configmap and services up${NORMAL}${ENDC}"
;;
appup)
startApps;
echo -e "${BLUE}${BOLD}All pods up${NORMAL}${ENDC}"
;;
allup)
startDependencies;
sleep 2
startApps;
echo -e "${BLUE}${BOLD}All resources and pods up${NORMAL}${ENDC}"
;;
scale-ro-server)
scaleROServers $1 $2
;;
scale-proxy)
scaleDeployment "milvus-proxy" $1 $2
;;
-h|--help|*)
showHelpMessage
;;
esac
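The setUpMysql function above extracts the MySQL user, password, and database name out of the backend_url entry in the server configmap with a chain of awk splits. The same decomposition is easier to see in Python (the URI value below is the sample from the mishards .env, assumed for illustration):

```python
from urllib.parse import urlparse

# Sample backend_url as it appears in the Milvus server config.
backend_url = "mysql+pymysql://root:root@milvus-mysql:3306/milvus"

parsed = urlparse(backend_url)
user = parsed.username           # what the first awk split recovers
password = parsed.password       # what the second awk split recovers
db_name = parsed.path.lstrip("/")  # what the third awk split recovers

print(user, password, db_name)
```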

shards/manager.py Normal file

@ -0,0 +1,18 @@
import fire
from mishards import db, settings
class DBHandler:
@classmethod
def create_all(cls):
db.create_all()
@classmethod
def drop_all(cls):
db.drop_all()
if __name__ == '__main__':
db.init_db(settings.DefaultConfig.SQLALCHEMY_DATABASE_URI)
from mishards import models
fire.Fire(DBHandler)
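Because manager.py hands DBHandler to python-fire, its classmethods become subcommands. A usage sketch, assuming the mishards dependencies are installed and SQLALCHEMY_DATABASE_URI points at a reachable database:

```shell
# Create all metadata tables defined in mishards.models
python manager.py create_all

# Drop them again
python manager.py drop_all
```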


@ -0,0 +1,36 @@
DEBUG=True
WOSERVER=tcp://127.0.0.1:19530
SERVER_PORT=19532
SERVER_TEST_PORT=19888
#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
SQL_ECHO=False
#SQLALCHEMY_DATABASE_TEST_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_TEST_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
SQL_TEST_ECHO=False
TRACER_PLUGIN_PATH=/tmp/plugins
# TRACING_TEST_TYPE=jaeger
TRACER_CLASS_NAME=jaeger
TRACING_SERVICE_NAME=fortest
TRACING_SAMPLER_TYPE=const
TRACING_SAMPLER_PARAM=1
TRACING_LOG_PAYLOAD=True
#TRACING_SAMPLER_TYPE=probabilistic
#TRACING_SAMPLER_PARAM=0.5
#DISCOVERY_PLUGIN_PATH=
#DISCOVERY_CLASS_NAME=kubernetes
DISCOVERY_STATIC_HOSTS=127.0.0.1
DISCOVERY_STATIC_PORT=19530
DISCOVERY_KUBERNETES_NAMESPACE=xp
DISCOVERY_KUBERNETES_POD_PATT=.*-ro-servers-.*
DISCOVERY_KUBERNETES_LABEL_SELECTOR=tier=ro-servers
DISCOVERY_KUBERNETES_POLL_INTERVAL=5
DISCOVERY_KUBERNETES_IN_CLUSTER=False


@ -0,0 +1,40 @@
import logging
from mishards import settings
logger = logging.getLogger()
from mishards.db_base import DB
db = DB()
from mishards.server import Server
grpc_server = Server()
def create_app(testing_config=None):
config = testing_config if testing_config else settings.DefaultConfig
db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO)
from mishards.connections import ConnectionMgr
connect_mgr = ConnectionMgr()
from discovery.factory import DiscoveryFactory
discover = DiscoveryFactory(config.DISCOVERY_PLUGIN_PATH).create(config.DISCOVERY_CLASS_NAME,
conn_mgr=connect_mgr)
from mishards.grpc_utils import GrpcSpanDecorator
from tracer.factory import TracerFactory
tracer = TracerFactory(config.TRACER_PLUGIN_PATH).create(config.TRACER_CLASS_NAME,
plugin_config=settings.TracingConfig,
span_decorator=GrpcSpanDecorator())
from mishards.router.factory import RouterFactory
router = RouterFactory(config.ROUTER_PLUGIN_PATH).create(config.ROUTER_CLASS_NAME,
conn_mgr=connect_mgr)
grpc_server.init_app(conn_mgr=connect_mgr,
tracer=tracer,
router=router,
discover=discover)
from mishards import exception_handlers
return grpc_server
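create_app wires discovery, tracing, and routing together through small plugin factories that share a `create(name, **kwargs)` shape. A minimal self-contained sketch of that factory pattern (class and registry names are hypothetical, not the actual mishards implementation):

```python
class PluginFactory:
    """Toy registry-backed factory mirroring the create(name, **kwargs) shape."""

    def __init__(self):
        self._registry = {}

    def register(self, name, cls):
        self._registry[name] = cls

    def create(self, name, **kwargs):
        cls = self._registry.get(name)
        if cls is None:
            raise ValueError('Unknown plugin: {}'.format(name))
        return cls(**kwargs)


class StaticDiscovery:
    """Stand-in for a discovery plugin; it only stores its connection manager."""

    def __init__(self, conn_mgr=None):
        self.conn_mgr = conn_mgr


factory = PluginFactory()
factory.register('static', StaticDiscovery)
discover = factory.create('static', conn_mgr={})
```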


@ -0,0 +1,154 @@
import logging
import threading
from functools import wraps
from milvus import Milvus
from mishards import (settings, exceptions)
from utils import singleton
logger = logging.getLogger(__name__)
class Connection:
def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
self.name = name
self.uri = uri
self.max_retry = max_retry
self.retried = 0
self.conn = Milvus()
self.error_handlers = [] if not error_handlers else error_handlers
self.on_retry_func = kwargs.get('on_retry_func', None)
# self._connect()
def __str__(self):
return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
def _connect(self, metadata=None):
try:
self.conn.connect(uri=self.uri)
except Exception as e:
if not self.error_handlers:
raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
for handler in self.error_handlers:
handler(e, metadata=metadata)
@property
def can_retry(self):
return self.retried < self.max_retry
@property
def connected(self):
return self.conn.connected()
def on_retry(self):
if self.on_retry_func:
self.on_retry_func(self)
else:
self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))
def on_connect(self, metadata=None):
while not self.connected and self.can_retry:
self.retried += 1
self.on_retry()
self._connect(metadata=metadata)
if not self.can_retry and not self.connected:
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry),
metadata=metadata)
self.retried = 0
def connect(self, func, exception_handler=None):
@wraps(func)
def inner(*args, **kwargs):
self.on_connect()
try:
return func(*args, **kwargs)
except Exception as e:
if exception_handler:
exception_handler(e)
else:
raise e
return inner
@singleton
class ConnectionMgr:
def __init__(self):
self.metas = {}
self.conns = {}
@property
def conn_names(self):
return set(self.metas.keys()) - set(['WOSERVER'])
def conn(self, name, metadata, throw=False):
c = self.conns.get(name, None)
if not c:
url = self.metas.get(name, None)
if not url:
if not throw:
return None
raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
threaded = {
threading.get_ident(): this_conn
}
self.conns[name] = threaded
return this_conn
tid = threading.get_ident()
rconn = c.get(tid, None)
if not rconn:
url = self.metas.get(name, None)
if not url:
if not throw:
return None
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
c[tid] = this_conn
return this_conn
return rconn
def on_new_meta(self, name, url):
logger.info('Register Connection: name={};url={}'.format(name, url))
self.metas[name] = url
def on_duplicate_meta(self, name, url):
if self.metas[name] == url:
return self.on_same_meta(name, url)
return self.on_diff_meta(name, url)
def on_same_meta(self, name, url):
# logger.warning('Register same meta: {}:{}'.format(name, url))
pass
def on_diff_meta(self, name, url):
logger.warning('Received {} with diff url={}'.format(name, url))
self.metas[name] = url
self.conns[name] = {}
def on_unregister_meta(self, name, url):
logger.info('Unregister name={};url={}'.format(name, url))
self.conns.pop(name, None)
def on_nonexisted_meta(self, name):
logger.warning('Non-existed meta: {}'.format(name))
def register(self, name, url):
meta = self.metas.get(name)
if not meta:
return self.on_new_meta(name, url)
else:
return self.on_duplicate_meta(name, url)
def unregister(self, name):
logger.info('Unregister Connection: name={}'.format(name))
url = self.metas.pop(name, None)
if url is None:
return self.on_nonexisted_meta(name)
return self.on_unregister_meta(name, url)
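ConnectionMgr.conn above caches one Connection per (name, thread) pair, keyed by threading.get_ident(), so concurrent gRPC worker threads never share a Milvus client. The caching scheme in isolation, with Connection replaced by a trivial stand-in (a sketch, not the mishards class):

```python
import threading


class ThreadLocalCache:
    """Per-(name, thread) object cache, mirroring ConnectionMgr.conn."""

    def __init__(self, factory):
        self.factory = factory  # callable that builds a fresh object per thread
        self.conns = {}         # name -> {thread_id: obj}

    def get(self, name):
        tid = threading.get_ident()
        per_thread = self.conns.setdefault(name, {})
        if tid not in per_thread:
            per_thread[tid] = self.factory(name)
        return per_thread[tid]


cache = ThreadLocalCache(factory=lambda name: object())
a = cache.get('ro_server')
b = cache.get('ro_server')
assert a is b  # the same thread always reuses its own object
```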


@ -0,0 +1,52 @@
import logging
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm.session import Session as SessionBase
logger = logging.getLogger(__name__)
class LocalSession(SessionBase):
def __init__(self, db, autocommit=False, autoflush=True, **options):
self.db = db
bind = options.pop('bind', None) or db.engine
SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options)
class DB:
Model = declarative_base()
def __init__(self, uri=None, echo=False):
self.echo = echo
uri and self.init_db(uri, echo)
self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self))
def init_db(self, uri, echo=False):
url = make_url(uri)
if url.get_backend_name() == 'sqlite':
self.engine = create_engine(url)
else:
self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
pool_pre_ping=True,
echo=echo,
max_overflow=0)
self.uri = uri
self.url = url
def __str__(self):
return '<DB: backend={};database={}>'.format(self.url.get_backend_name(), self.url.database)
@property
def Session(self):
return self.session_factory()
def remove_session(self):
self.session_factory.remove()
def drop_all(self):
self.Model.metadata.drop_all(self.engine)
def create_all(self):
self.Model.metadata.create_all(self.engine)


@ -0,0 +1,10 @@
INVALID_CODE = -1
CONNECT_ERROR_CODE = 10001
CONNECTTION_NOT_FOUND_CODE = 10002
DB_ERROR_CODE = 10003
TABLE_NOT_FOUND_CODE = 20001
INVALID_ARGUMENT_CODE = 20002
INVALID_DATE_RANGE_CODE = 20003
INVALID_TOPK_CODE = 20004


@ -0,0 +1,82 @@
import logging
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from mishards import grpc_server as server, exceptions
logger = logging.getLogger(__name__)
def resp_handler(err, error_code):
if not isinstance(err, exceptions.BaseException):
return status_pb2.Status(error_code=error_code, reason=str(err))
status = status_pb2.Status(error_code=error_code, reason=err.message)
if err.metadata is None:
return status
resp_class = err.metadata.get('resp_class', None)
if not resp_class:
return status
if resp_class == milvus_pb2.BoolReply:
return resp_class(status=status, bool_reply=False)
if resp_class == milvus_pb2.VectorIds:
return resp_class(status=status, vector_id_array=[])
if resp_class == milvus_pb2.TopKQueryResultList:
return resp_class(status=status, topk_query_result=[])
if resp_class == milvus_pb2.TableRowCount:
return resp_class(status=status, table_row_count=-1)
if resp_class == milvus_pb2.TableName:
return resp_class(status=status, table_name='')
if resp_class == milvus_pb2.StringReply:
return resp_class(status=status, string_reply='')
if resp_class == milvus_pb2.TableSchema:
return milvus_pb2.TableSchema(
status=status
)
if resp_class == milvus_pb2.IndexParam:
return milvus_pb2.IndexParam(
table_name=milvus_pb2.TableName(
status=status
)
)
status.error_code = status_pb2.UNEXPECTED_ERROR
return status
@server.errorhandler(exceptions.TableNotFoundError)
def TableNotFoundErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
@server.errorhandler(exceptions.InvalidTopKError)
def InvalidTopKErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_TOPK)
@server.errorhandler(exceptions.InvalidArgumentError)
def InvalidArgumentErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)
@server.errorhandler(exceptions.DBError)
def DBErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.UNEXPECTED_ERROR)
@server.errorhandler(exceptions.InvalidRangeError)
def InvalidRangeErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_RANGE)


@ -0,0 +1,38 @@
import mishards.exception_codes as codes
class BaseException(Exception):
code = codes.INVALID_CODE
message = 'BaseException'
def __init__(self, message='', metadata=None):
self.message = self.__class__.__name__ if not message else message
self.metadata = metadata
class ConnectionConnectError(BaseException):
code = codes.CONNECT_ERROR_CODE
class ConnectionNotFoundError(BaseException):
code = codes.CONNECTTION_NOT_FOUND_CODE
class DBError(BaseException):
code = codes.DB_ERROR_CODE
class TableNotFoundError(BaseException):
code = codes.TABLE_NOT_FOUND_CODE
class InvalidTopKError(BaseException):
code = codes.INVALID_TOPK_CODE
class InvalidArgumentError(BaseException):
code = codes.INVALID_ARGUMENT_CODE
class InvalidRangeError(BaseException):
code = codes.INVALID_DATE_RANGE_CODE


@ -0,0 +1,54 @@
import time
import datetime
import random
import factory
from factory.alchemy import SQLAlchemyModelFactory
from faker import Faker
from faker.providers import BaseProvider
from milvus.client.types import MetricType
from mishards import db
from mishards.models import Tables, TableFiles
class FakerProvider(BaseProvider):
def this_date(self):
t = datetime.datetime.today()
return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day
factory.Faker.add_provider(FakerProvider)
class TablesFactory(SQLAlchemyModelFactory):
class Meta:
model = Tables
sqlalchemy_session = db.session_factory
sqlalchemy_session_persistence = 'commit'
id = factory.Faker('random_number', digits=16, fix_len=True)
table_id = factory.Faker('uuid4')
state = factory.Faker('random_element', elements=(0, 1))
dimension = factory.Faker('random_element', elements=(256, 512))
created_on = int(time.time())
index_file_size = 0
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
metric_type = factory.Faker('random_element', elements=(MetricType.L2, MetricType.IP))
nlist = 16384
class TableFilesFactory(SQLAlchemyModelFactory):
class Meta:
model = TableFiles
sqlalchemy_session = db.session_factory
sqlalchemy_session_persistence = 'commit'
id = factory.Faker('random_number', digits=16, fix_len=True)
table = factory.SubFactory(TablesFactory)
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
file_id = factory.Faker('uuid4')
file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
file_size = factory.Faker('random_number')
updated_time = int(time.time())
created_on = int(time.time())
date = factory.Faker('this_date')
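The this_date provider above packs today's date into the integer format used by the metadata tables: (year - 1900) * 10000 + (month - 1) * 100 + day. Worked out for a fixed date (chosen here only for illustration):

```python
import datetime


def encode_date(t):
    # Same packing as FakerProvider.this_date, but for an arbitrary date.
    return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day


# 2019-11-07 -> 119 * 10000 + 10 * 100 + 7 = 1191007
assert encode_date(datetime.date(2019, 11, 7)) == 1191007
```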


@ -0,0 +1,37 @@
from grpc_opentracing import SpanDecorator
from milvus.grpc_gen import status_pb2
class GrpcSpanDecorator(SpanDecorator):
def __call__(self, span, rpc_info):
status = None
if not rpc_info.response:
return
if isinstance(rpc_info.response, status_pb2.Status):
status = rpc_info.response
else:
try:
status = rpc_info.response.status
except Exception as e:
status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR,
reason='Should not happen')
if status.error_code == 0:
return
error_log = {'event': 'error',
'request': rpc_info.request,
'response': rpc_info.response
}
span.set_tag('error', True)
span.log_kv(error_log)
def mark_grpc_method(func):
setattr(func, 'grpc_method', True)
return func
def is_grpc_method(func):
if not func:
return False
return getattr(func, 'grpc_method', False)


@ -0,0 +1,102 @@
from milvus import Status
from functools import wraps
def error_status(func):
@wraps(func)
def inner(*args, **kwargs):
try:
results = func(*args, **kwargs)
except Exception as e:
return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None
return Status(code=0, message="Success"), results
return inner
class GrpcArgsParser(object):
@classmethod
@error_status
def parse_proto_TableSchema(cls, param):
_table_schema = {
'status': param.status,
'table_name': param.table_name,
'dimension': param.dimension,
'index_file_size': param.index_file_size,
'metric_type': param.metric_type
}
return _table_schema
@classmethod
@error_status
def parse_proto_TableName(cls, param):
return param.table_name
@classmethod
@error_status
def parse_proto_Index(cls, param):
_index = {
'index_type': param.index_type,
'nlist': param.nlist
}
return _index
@classmethod
@error_status
def parse_proto_IndexParam(cls, param):
_table_name = param.table_name
_status, _index = cls.parse_proto_Index(param.index)
if not _status.OK():
raise Exception("Argument parse error")
return _table_name, _index
@classmethod
@error_status
def parse_proto_Command(cls, param):
_cmd = param.cmd
return _cmd
@classmethod
@error_status
def parse_proto_Range(cls, param):
_start_value = param.start_value
_end_value = param.end_value
return _start_value, _end_value
@classmethod
@error_status
def parse_proto_RowRecord(cls, param):
return list(param.vector_data)
@classmethod
@error_status
def parse_proto_SearchParam(cls, param):
_table_name = param.table_name
_topk = param.topk
_nprobe = param.nprobe
_status, _range = cls.parse_proto_Range(param.query_range_array)
if not _status.OK():
raise Exception("Argument parse error")
_row_record = param.query_record_array
return _table_name, _row_record, _range, _topk
@classmethod
@error_status
def parse_proto_DeleteByRangeParam(cls, param):
_table_name = param.table_name
_range = param.range
_start_value = _range.start_value
_end_value = _range.end_value
return _table_name, _start_value, _end_value
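The error_status decorator above converts parser exceptions into a (Status, result) pair so callers never see a raw traceback. The same pattern, with the Milvus Status class replaced by a minimal stand-in so the sketch is self-contained:

```python
from functools import wraps


class Status:
    """Minimal stand-in for milvus.Status."""
    UNEXPECTED_ERROR = 1

    def __init__(self, code, message):
        self.code, self.message = code, message

    def OK(self):
        return self.code == 0


def error_status(func):
    @wraps(func)
    def inner(*args, **kwargs):
        try:
            results = func(*args, **kwargs)
        except Exception as e:
            # Any failure is folded into a non-OK status with no result.
            return Status(Status.UNEXPECTED_ERROR, str(e)), None
        return Status(0, 'Success'), results
    return inner


@error_status
def parse(value):
    return int(value)
```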


@ -0,0 +1,4 @@
# class GrpcArgsWrapper(object):
# @classmethod
# def proto_TableName(cls):


@ -0,0 +1,75 @@
import logging
import opentracing
from mishards.grpc_utils import GrpcSpanDecorator, is_grpc_method
from milvus.grpc_gen import status_pb2, milvus_pb2
logger = logging.getLogger(__name__)
class FakeTracer(opentracing.Tracer):
pass
class FakeSpan(opentracing.Span):
def __init__(self, context, tracer, **kwargs):
super(FakeSpan, self).__init__(tracer, context)
self.reset()
def set_tag(self, key, value):
self.tags.append({key: value})
def log_kv(self, key_values, timestamp=None):
self.logs.append(key_values)
def reset(self):
self.tags = []
self.logs = []
class FakeRpcInfo:
def __init__(self, request, response):
self.request = request
self.response = response
class TestGrpcUtils:
def test_span_deco(self):
request = 'request'
OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success')
response = OK
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
response = milvus_pb2.BoolReply(status=OK, bool_reply=False)
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
response = 1
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 1
assert len(span.tags) == 1
response = 0
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
def test_is_grpc_method(self):
target = 1
assert not is_grpc_method(target)
target = None
assert not is_grpc_method(target)


@ -0,0 +1,150 @@
import math
import sys
from bisect import bisect
if sys.version_info >= (2, 5):
import hashlib
md5_constructor = hashlib.md5
else:
import md5
md5_constructor = md5.new
class HashRing(object):
def __init__(self, nodes=None, weights=None):
"""`nodes` is a list of objects that have a proper __str__ representation.
`weights` is a dictionary mapping nodes to their weights. By default
all nodes carry equal weight.
"""
self.ring = dict()
self._sorted_keys = []
self.nodes = nodes
if not weights:
weights = {}
self.weights = weights
self._generate_circle()
def _generate_circle(self):
"""Generates the circle.
"""
total_weight = 0
for node in self.nodes:
total_weight += self.weights.get(node, 1)
for node in self.nodes:
weight = 1
if node in self.weights:
weight = self.weights.get(node)
factor = math.floor((40 * len(self.nodes) * weight) / total_weight)
for j in range(0, int(factor)):
b_key = self._hash_digest('%s-%s' % (node, j))
for i in range(0, 3):
key = self._hash_val(b_key, lambda x: x + i * 4)
self.ring[key] = node
self._sorted_keys.append(key)
self._sorted_keys.sort()
def get_node(self, string_key):
        """Given a string key, return the corresponding node in the hash ring.
        If the hash ring is empty, `None` is returned.
        """
pos = self.get_node_pos(string_key)
if pos is None:
return None
return self.ring[self._sorted_keys[pos]]
def get_node_pos(self, string_key):
        """Given a string key, return the position of the corresponding node
        in the ring.
        If the hash ring is empty, `None` is returned.
        """
if not self.ring:
return None
key = self.gen_key(string_key)
nodes = self._sorted_keys
pos = bisect(nodes, key)
if pos == len(nodes):
return 0
else:
return pos
def iterate_nodes(self, string_key, distinct=True):
        """Given a string key, yield the nodes that can hold the key, iterating
        once through the ring starting at the key's position.
        If `distinct` is set, the yielded nodes are unique, i.e. no virtual
        copies are returned.
        """
        if not self.ring:
            yield None, None
            return
returned_values = set()
def distinct_filter(value):
if str(value) not in returned_values:
returned_values.add(str(value))
return value
pos = self.get_node_pos(string_key)
for key in self._sorted_keys[pos:]:
val = distinct_filter(self.ring[key])
if val:
yield val
for i, key in enumerate(self._sorted_keys):
if i < pos:
val = distinct_filter(self.ring[key])
if val:
yield val
def gen_key(self, key):
        """Given a string key, return a long value that represents the key's
        place on the hash ring.
"""
b_key = self._hash_digest(key)
return self._hash_val(b_key, lambda x: x)
def _hash_val(self, b_key, entry_fn):
return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | (
b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)]
def _hash_digest(self, key):
m = md5_constructor()
key = key.encode()
m.update(key)
return m.digest()
if __name__ == '__main__':
from collections import defaultdict
servers = [
'192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212',
'192.168.0.249:11212'
]
ring = HashRing(servers)
keys = ['{}'.format(i) for i in range(100)]
mapped = defaultdict(list)
for k in keys:
server = ring.get_node(k)
mapped[server].append(k)
for k, v in mapped.items():
print(k, v)
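The virtual-node placement above can be exercised in isolation. The following is a minimal, self-contained sketch (`hash_val` here is our stand-alone copy of `_hash_val`; the node string is arbitrary) showing how each md5 digest yields three 32-bit ring positions, matching the `range(0, 3)` loop in `_generate_circle`:

```python
import hashlib

def hash_val(b_key, entry_fn):
    # Pack four digest bytes (little-endian) into one 32-bit ring position,
    # mirroring HashRing._hash_val above.
    return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | \
           (b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)]

# One md5 digest per (node, replica) pair, as in _generate_circle.
digest = hashlib.md5('192.168.0.246:11212-0'.encode()).digest()
# The 16-byte digest is consumed 4 bytes at a time: offsets 0, 4, 8.
positions = [hash_val(digest, lambda x, i=i: x + i * 4) for i in range(3)]
assert len(positions) == 3
assert all(0 <= p < 2 ** 32 for p in positions)
```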

shards/mishards/main.py Normal file
@ -0,0 +1,15 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from mishards import (settings, create_app)
def main():
server = create_app(settings.DefaultConfig)
server.run(port=settings.SERVER_PORT)
return 0
if __name__ == '__main__':
sys.exit(main())

shards/mishards/models.py Normal file
@ -0,0 +1,76 @@
import logging
from sqlalchemy import (Integer, Boolean, Text,
String, BigInteger, and_, or_,
Column)
from sqlalchemy.orm import relationship, backref
from mishards import db
logger = logging.getLogger(__name__)
class TableFiles(db.Model):
FILE_TYPE_NEW = 0
FILE_TYPE_RAW = 1
FILE_TYPE_TO_INDEX = 2
FILE_TYPE_INDEX = 3
FILE_TYPE_TO_DELETE = 4
FILE_TYPE_NEW_MERGE = 5
FILE_TYPE_NEW_INDEX = 6
FILE_TYPE_BACKUP = 7
__tablename__ = 'TableFiles'
id = Column(BigInteger, primary_key=True, autoincrement=True)
table_id = Column(String(50))
engine_type = Column(Integer)
file_id = Column(String(50))
file_type = Column(Integer)
file_size = Column(Integer, default=0)
row_count = Column(Integer, default=0)
updated_time = Column(BigInteger)
created_on = Column(BigInteger)
date = Column(Integer)
table = relationship(
'Tables',
primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)',
backref=backref('files', uselist=True, lazy='dynamic')
)
class Tables(db.Model):
TO_DELETE = 1
NORMAL = 0
__tablename__ = 'Tables'
id = Column(BigInteger, primary_key=True, autoincrement=True)
table_id = Column(String(50), unique=True)
state = Column(Integer)
dimension = Column(Integer)
created_on = Column(Integer)
flag = Column(Integer, default=0)
index_file_size = Column(Integer)
engine_type = Column(Integer)
nlist = Column(Integer)
metric_type = Column(Integer)
def files_to_search(self, date_range=None):
cond = or_(
TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX,
TableFiles.file_type == TableFiles.FILE_TYPE_INDEX,
)
if date_range:
cond = and_(
cond,
                or_(*[
                    and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range
                ])
)
files = self.files.filter(cond)
logger.debug('DATE_RANGE: {}'.format(date_range))
return files
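The predicate built by `files_to_search` can be paraphrased without SQLAlchemy. This is a minimal sketch under stated assumptions: the `searchable` helper and the `SEARCHABLE` set are ours, mirroring the `FILE_TYPE_*` constants above:

```python
# Searchable file types: RAW (1), TO_INDEX (2), INDEX (3), as defined on TableFiles.
SEARCHABLE = {1, 2, 3}

def searchable(file_type, date, date_range=None):
    # A file is a search candidate if it is raw, pending index, or indexed,
    # and (when a range filter is given) its date falls in some [start, end) window.
    if file_type not in SEARCHABLE:
        return False
    if not date_range:
        return True
    return any(start <= date < end for start, end in date_range)

assert searchable(1, 120, [(111, 121)])   # raw file inside the window
assert not searchable(7, 120)             # FILE_TYPE_BACKUP is never searched
assert not searchable(3, 120, [(100, 115)])  # indexed file outside the window
```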

@ -0,0 +1,22 @@
from mishards import exceptions
class RouterMixin:
def __init__(self, conn_mgr):
self.conn_mgr = conn_mgr
def routing(self, table_name, metadata=None, **kwargs):
        raise NotImplementedError()
def connection(self, metadata=None):
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
if conn:
conn.on_connect(metadata=metadata)
return conn.conn
def query_conn(self, name, metadata=None):
conn = self.conn_mgr.conn(name, metadata=metadata)
if not conn:
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
conn.on_connect(metadata=metadata)
return conn.conn
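A concrete router plugs in by overriding `routing()`. As a hedged illustration (`FakeConnMgr` and `RoundRobinRouter` are hypothetical, not part of mishards), a toy router honoring the same return shape might look like:

```python
class FakeConnMgr:
    # Stand-in for the real connection manager: only conn_names is needed here.
    def __init__(self, names):
        self.conn_names = set(names)

class RoundRobinRouter:
    # Minimal RouterMixin-style router: send each request to servers in turn.
    def __init__(self, conn_mgr):
        self.conn_mgr = conn_mgr
        self._cursor = 0

    def routing(self, table_name, metadata=None, **kwargs):
        servers = sorted(self.conn_mgr.conn_names)
        target = servers[self._cursor % len(servers)]
        self._cursor += 1
        # Same shape FileBasedHashRingRouter returns: host -> sub-request.
        return {target: {'table_id': table_name, 'file_ids': []}}

router = RoundRobinRouter(FakeConnMgr(['s1', 's2']))
assert set(router.routing('t')) == {'s1'}
assert set(router.routing('t')) == {'s2'}
```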

@ -0,0 +1,17 @@
import os
import logging
from utils.plugins import BaseMixin
logger = logging.getLogger(__name__)
PLUGIN_PACKAGE_NAME = 'mishards.router.plugins'
class RouterFactory(BaseMixin):
PLUGIN_TYPE = 'Router'
def __init__(self, searchpath=None):
super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME)
def _create(self, plugin_class, **kwargs):
router = plugin_class.Create(**kwargs)
return router

@ -0,0 +1,64 @@
import logging
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import and_
from mishards.models import Tables
from mishards.router import RouterMixin
from mishards import exceptions, db
from mishards.hash_ring import HashRing
logger = logging.getLogger(__name__)
class Factory(RouterMixin):
name = 'FileBasedHashRingRouter'
def __init__(self, conn_mgr, **kwargs):
super(Factory, self).__init__(conn_mgr)
def routing(self, table_name, metadata=None, **kwargs):
range_array = kwargs.pop('range_array', None)
return self._route(table_name, range_array, metadata, **kwargs)
def _route(self, table_name, range_array, metadata=None, **kwargs):
# PXU TODO: Implement Thread-local Context
# PXU TODO: Session life mgt
try:
table = db.Session.query(Tables).filter(
and_(Tables.table_id == table_name,
Tables.state != Tables.TO_DELETE)).first()
except sqlalchemy_exc.SQLAlchemyError as e:
raise exceptions.DBError(message=str(e), metadata=metadata)
if not table:
raise exceptions.TableNotFoundError(table_name, metadata=metadata)
files = table.files_to_search(range_array)
db.remove_session()
servers = self.conn_mgr.conn_names
logger.info('Available servers: {}'.format(servers))
ring = HashRing(servers)
routing = {}
for f in files:
target_host = ring.get_node(str(f.id))
sub = routing.get(target_host, None)
if not sub:
routing[target_host] = {'table_id': table_name, 'file_ids': []}
routing[target_host]['file_ids'].append(str(f.id))
return routing
@classmethod
def Create(cls, **kwargs):
conn_mgr = kwargs.pop('conn_mgr', None)
if not conn_mgr:
            raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(cls.name))
router = cls(conn_mgr, **kwargs)
return router
def setup(app):
    logger.info('Plugin \'{}\' installed in package: {}'.format(__file__, app.plugin_package_name))
app.on_plugin_setup(Factory)
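The grouping loop inside `Factory._route` boils down to bucketing file ids by their ring node. A stand-alone sketch (`group_by_node` and the modulo-based `get_node` stand-in are ours; the real code uses `HashRing.get_node`):

```python
def group_by_node(file_ids, get_node):
    # Mirrors the routing loop in Factory._route: one sub-request per target host,
    # each carrying the file ids that hash onto that host.
    routing = {}
    for fid in file_ids:
        host = get_node(str(fid))
        routing.setdefault(host, {'table_id': 'demo', 'file_ids': []})
        routing[host]['file_ids'].append(str(fid))
    return routing

# Deterministic stand-in for the hash ring: even ids to server-0, odd to server-1.
table = group_by_node(range(4), lambda fid: 'server-{}'.format(int(fid) % 2))
assert table['server-0']['file_ids'] == ['0', '2']
assert table['server-1']['file_ids'] == ['1', '3']
```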

shards/mishards/server.py Normal file
@ -0,0 +1,122 @@
import logging
import grpc
import time
import socket
import inspect
from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
from grpc._cython import cygrpc
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from mishards.grpc_utils import is_grpc_method
from mishards.service_handler import ServiceHandler
from mishards import settings
logger = logging.getLogger(__name__)
class Server:
def __init__(self):
self.pre_run_handlers = set()
self.grpc_methods = set()
self.error_handlers = {}
self.exit_flag = False
def init_app(self,
conn_mgr,
tracer,
router,
discover,
port=19530,
max_workers=10,
**kwargs):
self.port = int(port)
self.conn_mgr = conn_mgr
self.tracer = tracer
self.router = router
self.discover = discover
self.server_impl = grpc.server(
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
(cygrpc.ChannelArgKey.max_receive_message_length, -1)])
self.server_impl = self.tracer.decorate(self.server_impl)
self.register_pre_run_handler(self.pre_run_handler)
def pre_run_handler(self):
woserver = settings.WOSERVER
url = urlparse(woserver)
ip = socket.gethostbyname(url.hostname)
socket.inet_pton(socket.AF_INET, ip)
self.conn_mgr.register(
'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
def register_pre_run_handler(self, func):
        logger.info('Registering {} into server pre_run_handlers'.format(func))
self.pre_run_handlers.add(func)
return func
def wrap_method_with_errorhandler(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if e.__class__ in self.error_handlers:
return self.error_handlers[e.__class__](e)
raise
return wrapper
def errorhandler(self, exception):
if inspect.isclass(exception) and issubclass(exception, Exception):
def wrapper(func):
self.error_handlers[exception] = func
return func
return wrapper
return exception
def on_pre_run(self):
for handler in self.pre_run_handlers:
handler()
self.discover.start()
def start(self, port=None):
handler_class = self.decorate_handler(ServiceHandler)
add_MilvusServiceServicer_to_server(
handler_class(tracer=self.tracer,
router=self.router), self.server_impl)
self.server_impl.add_insecure_port("[::]:{}".format(
str(port or self.port)))
self.server_impl.start()
def run(self, port):
        logger.info('Milvus server starting ...')
port = port or self.port
self.on_pre_run()
self.start(port)
logger.info('Listening on port {}'.format(port))
try:
while not self.exit_flag:
time.sleep(5)
except KeyboardInterrupt:
self.stop()
def stop(self):
        logger.info('Server is shutting down ...')
self.exit_flag = True
self.server_impl.stop(0)
self.tracer.close()
logger.info('Server is closed')
def decorate_handler(self, handler):
for key, attr in handler.__dict__.items():
if is_grpc_method(attr):
setattr(handler, key, self.wrap_method_with_errorhandler(attr))
return handler
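The `errorhandler` / `wrap_method_with_errorhandler` pair above follows a common registration idiom: map exception classes to handlers, then wrap each gRPC method so registered exceptions become responses instead of propagating. A self-contained sketch of the same pattern (module-level registry instead of the `Server` instance; names are ours):

```python
import functools

error_handlers = {}

def errorhandler(exc_class):
    # Decorator that registers `func` as the handler for `exc_class`,
    # like Server.errorhandler above.
    def wrapper(func):
        error_handlers[exc_class] = func
        return func
    return wrapper

def with_errorhandler(func):
    # Like Server.wrap_method_with_errorhandler: convert registered
    # exceptions into return values, re-raise everything else.
    @functools.wraps(func)
    def inner(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            if e.__class__ in error_handlers:
                return error_handlers[e.__class__](e)
            raise
    return inner

@errorhandler(KeyError)
def on_key_error(err):
    return 'handled'

@with_errorhandler
def lookup(d, k):
    return d[k]

assert lookup({'k': 1}, 'k') == 1
assert lookup({}, 'missing') == 'handled'  # KeyError intercepted
```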

@ -0,0 +1,475 @@
import logging
import time
import datetime
from collections import defaultdict
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client.abstract import Range
from milvus.client import types as Types
from mishards import (db, settings, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards import utilities
logger = logging.getLogger(__name__)
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
MAX_NPROBE = 2048
MAX_TOPK = 2048
def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs):
self.table_meta = {}
self.error_handlers = {}
self.tracer = tracer
self.router = router
self.max_workers = max_workers
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
status = status_pb2.Status(error_code=status_pb2.SUCCESS,
reason="Success")
if not files_n_topk_results:
return status, []
request_results = defaultdict(list)
calc_time = time.time()
for files_collection in files_n_topk_results:
if isinstance(files_collection, tuple):
status, _ = files_collection
return status, []
for request_pos, each_request_results in enumerate(
files_collection.topk_query_result):
request_results[request_pos].extend(
each_request_results.query_result_arrays)
request_results[request_pos] = sorted(
request_results[request_pos],
key=lambda x: x.distance,
reverse=reverse)[:topk]
calc_time = time.time() - calc_time
logger.info('Merge takes {}'.format(calc_time))
results = sorted(request_results.items())
topk_query_result = []
for result in results:
query_result = TopKQueryResult(query_result_arrays=result[1])
topk_query_result.append(query_result)
return status, topk_query_result
def _do_query(self,
context,
table_id,
table_meta,
vectors,
topk,
nprobe,
range_array=None,
**kwargs):
metadata = kwargs.get('metadata', None)
range_array = [
utilities.range_to_date(r, metadata=metadata) for r in range_array
] if range_array else None
routing = {}
p_span = None if self.tracer.empty else context.get_active_span(
).context
with self.tracer.start_span('get_routing', child_of=p_span):
routing = self.router.routing(table_id,
range_array=range_array,
metadata=metadata)
logger.info('Routing: {}'.format(routing))
metadata = kwargs.get('metadata', None)
rs = []
all_topk_results = []
def search(addr, query_params, vectors, topk, nprobe, **kwargs):
logger.info(
'Send Search Request: addr={};params={};nq={};topk={};nprobe={}'
.format(addr, query_params, len(vectors), topk, nprobe))
conn = self.router.query_conn(addr, metadata=metadata)
start = time.time()
span = kwargs.get('span', None)
span = span if span else (None if self.tracer.empty else
context.get_active_span().context)
with self.tracer.start_span('search_{}'.format(addr),
child_of=span):
ret = conn.search_vectors_in_files(
table_name=query_params['table_id'],
file_ids=query_params['file_ids'],
query_records=vectors,
top_k=topk,
nprobe=nprobe,
lazy_=True)
end = time.time()
logger.info('search_vectors_in_files takes: {}'.format(end - start))
all_topk_results.append(ret)
with self.tracer.start_span('do_search', child_of=p_span) as span:
with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
for addr, params in routing.items():
res = pool.submit(search,
addr,
params,
vectors,
topk,
nprobe,
span=span)
rs.append(res)
for res in rs:
res.result()
reverse = table_meta.metric_type == Types.MetricType.IP
with self.tracer.start_span('do_merge', child_of=p_span):
return self._do_merge(all_topk_results,
topk,
reverse=reverse,
metadata=metadata)
def _create_table(self, table_schema):
return self.router.connection().create_table(table_schema)
@mark_grpc_method
def CreateTable(self, request, context):
_status, _table_schema = Parser.parse_proto_TableSchema(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('CreateTable {}'.format(_table_schema['table_name']))
_status = self._create_table(_table_schema)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _has_table(self, table_name, metadata=None):
return self.router.connection(metadata=metadata).has_table(table_name)
@mark_grpc_method
def HasTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.BoolReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
bool_reply=False)
logger.info('HasTable {}'.format(_table_name))
_status, _bool = self._has_table(_table_name,
metadata={'resp_class': milvus_pb2.BoolReply})
return milvus_pb2.BoolReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
bool_reply=_bool)
def _delete_table(self, table_name):
return self.router.connection().delete_table(table_name)
@mark_grpc_method
def DropTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('DropTable {}'.format(_table_name))
_status = self._delete_table(_table_name)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _create_index(self, table_name, index):
return self.router.connection().create_index(table_name, index)
@mark_grpc_method
def CreateIndex(self, request, context):
_status, unpacks = Parser.parse_proto_IndexParam(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
_table_name, _index = unpacks
logger.info('CreateIndex {}'.format(_table_name))
        # TODO: interface create_table is incomplete
_status = self._create_index(_table_name, _index)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _add_vectors(self, param, metadata=None):
return self.router.connection(metadata=metadata).add_vectors(
None, None, insert_param=param)
@mark_grpc_method
def Insert(self, request, context):
logger.info('Insert')
        # TODO: The SDK interface add_vectors() could be updated to accept a 'row_id_array' key
_status, _ids = self._add_vectors(
metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
return milvus_pb2.VectorIds(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
vector_id_array=_ids)
@mark_grpc_method
def Search(self, request, context):
table_name = request.table_name
topk = request.topk
nprobe = request.nprobe
logger.info('Search {}: topk={} nprobe={}'.format(
table_name, topk, nprobe))
metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
if nprobe > self.MAX_NPROBE or nprobe <= 0:
raise exceptions.InvalidArgumentError(
message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)
if topk > self.MAX_TOPK or topk <= 0:
raise exceptions.InvalidTopKError(
message='Invalid topk: {}'.format(topk), metadata=metadata)
table_meta = self.table_meta.get(table_name, None)
if not table_meta:
status, info = self.router.connection(
metadata=metadata).describe_table(table_name)
if not status.OK():
raise exceptions.TableNotFoundError(table_name,
metadata=metadata)
self.table_meta[table_name] = info
table_meta = info
start = time.time()
query_record_array = []
for query_record in request.query_record_array:
query_record_array.append(list(query_record.vector_data))
query_range_array = []
for query_range in request.query_range_array:
query_range_array.append(
Range(query_range.start_value, query_range.end_value))
status, results = self._do_query(context,
table_name,
table_meta,
query_record_array,
topk,
nprobe,
query_range_array,
metadata=metadata)
now = time.time()
logger.info('SearchVector takes: {}'.format(now - start))
topk_result_list = milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=status.error_code,
reason=status.reason),
topk_query_result=results)
return topk_result_list
@mark_grpc_method
def SearchInFiles(self, request, context):
        raise NotImplementedError()
def _describe_table(self, table_name, metadata=None):
return self.router.connection(metadata=metadata).describe_table(table_name)
@mark_grpc_method
def DescribeTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.TableSchema(status=status_pb2.Status(
error_code=_status.code, reason=_status.message), )
metadata = {'resp_class': milvus_pb2.TableSchema}
logger.info('DescribeTable {}'.format(_table_name))
_status, _table = self._describe_table(metadata=metadata,
table_name=_table_name)
if _status.OK():
return milvus_pb2.TableSchema(
table_name=_table_name,
index_file_size=_table.index_file_size,
dimension=_table.dimension,
metric_type=_table.metric_type,
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
)
return milvus_pb2.TableSchema(
table_name=_table_name,
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
)
def _count_table(self, table_name, metadata=None):
return self.router.connection(
metadata=metadata).get_table_row_count(table_name)
@mark_grpc_method
def CountTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
status = status_pb2.Status(error_code=_status.code,
reason=_status.message)
return milvus_pb2.TableRowCount(status=status)
logger.info('CountTable {}'.format(_table_name))
metadata = {'resp_class': milvus_pb2.TableRowCount}
_status, _count = self._count_table(_table_name, metadata=metadata)
return milvus_pb2.TableRowCount(
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
table_row_count=_count if isinstance(_count, int) else -1)
def _get_server_version(self, metadata=None):
return self.router.connection(metadata=metadata).server_version()
@mark_grpc_method
def Cmd(self, request, context):
_status, _cmd = Parser.parse_proto_Command(request)
logger.info('Cmd: {}'.format(_cmd))
if not _status.OK():
return milvus_pb2.StringReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message))
metadata = {'resp_class': milvus_pb2.StringReply}
if _cmd == 'version':
_status, _reply = self._get_server_version(metadata=metadata)
else:
_status, _reply = self.router.connection(
metadata=metadata).server_status()
return milvus_pb2.StringReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
string_reply=_reply)
def _show_tables(self, metadata=None):
return self.router.connection(metadata=metadata).show_tables()
@mark_grpc_method
def ShowTables(self, request, context):
logger.info('ShowTables')
metadata = {'resp_class': milvus_pb2.TableName}
_status, _results = self._show_tables(metadata=metadata)
return milvus_pb2.TableNameList(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
table_names=_results)
def _delete_by_range(self, table_name, start_date, end_date):
return self.router.connection().delete_vectors_by_range(table_name,
start_date,
end_date)
@mark_grpc_method
def DeleteByRange(self, request, context):
_status, unpacks = \
Parser.parse_proto_DeleteByRangeParam(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
_table_name, _start_date, _end_date = unpacks
logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date,
_end_date))
_status = self._delete_by_range(_table_name, _start_date, _end_date)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _preload_table(self, table_name):
return self.router.connection().preload_table(table_name)
@mark_grpc_method
def PreloadTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('PreloadTable {}'.format(_table_name))
_status = self._preload_table(_table_name)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _describe_index(self, table_name, metadata=None):
return self.router.connection(metadata=metadata).describe_index(table_name)
@mark_grpc_method
def DescribeIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.IndexParam(status=status_pb2.Status(
error_code=_status.code, reason=_status.message))
metadata = {'resp_class': milvus_pb2.IndexParam}
logger.info('DescribeIndex {}'.format(_table_name))
_status, _index_param = self._describe_index(table_name=_table_name,
metadata=metadata)
if not _index_param:
return milvus_pb2.IndexParam(status=status_pb2.Status(
error_code=_status.code, reason=_status.message))
_index = milvus_pb2.Index(index_type=_index_param._index_type,
nlist=_index_param._nlist)
return milvus_pb2.IndexParam(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
table_name=_table_name,
index=_index)
def _drop_index(self, table_name):
return self.router.connection().drop_index(table_name)
@mark_grpc_method
def DropIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('DropIndex {}'.format(_table_name))
_status = self._drop_index(_table_name)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
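The merge step in `_do_merge` reduces to pooling each shard's hits per request position and re-sorting by distance. A minimal sketch using `(id, distance)` tuples in place of `query_result_arrays` (our simplification; `reverse=True` corresponds to inner-product metrics, where larger is better):

```python
def merge_topk(shard_results, topk, reverse=False):
    # shard_results: one list per shard, each holding per-request hit lists.
    merged = []
    for hits in zip(*shard_results):
        # Concatenate every shard's hits for this request position ...
        pooled = [h for shard in hits for h in shard]
        # ... then keep the global top-k by distance.
        merged.append(sorted(pooled, key=lambda x: x[1], reverse=reverse)[:topk])
    return merged

shard_a = [[('id1', 0.1), ('id2', 0.4)]]
shard_b = [[('id3', 0.2)]]
assert merge_topk([shard_a, shard_b], topk=2) == [[('id1', 0.1), ('id3', 0.2)]]
```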

@ -0,0 +1,69 @@
import sys
import os
from environs import Env
env = Env()
FROM_EXAMPLE = env.bool('FROM_EXAMPLE', False)
if FROM_EXAMPLE:
from dotenv import load_dotenv
load_dotenv('./mishards/.env.example')
else:
env.read_env()
DEBUG = env.bool('DEBUG', False)
MAX_RETRY = env.int('MAX_RETRY', 3)
LOG_LEVEL = env.str('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO')
LOG_PATH = env.str('LOG_PATH', '/tmp/mishards')
LOG_NAME = env.str('LOG_NAME', 'logfile')
TIMEZONE = env.str('TIMEZONE', 'UTC')
from utils.logger_helper import config
config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
SERVER_PORT = env.int('SERVER_PORT', 19530)
SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530)
WOSERVER = env.str('WOSERVER')
class TracingConfig:
TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards')
TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True)
TRACING_LOG_PAYLOAD = env.bool('TRACING_LOG_PAYLOAD', False)
TRACING_CONFIG = {
'sampler': {
'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
'param': env.str('TRACING_SAMPLER_PARAM', "1"),
},
'local_agent': {
'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'),
'reporting_port': env.str('TRACING_REPORTING_PORT', '5775')
},
'logging': env.bool('TRACING_LOGGING', True)
}
DEFAULT_TRACING_CONFIG = {
'sampler': {
'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
'param': env.str('TRACING_SAMPLER_PARAM', "0"),
}
}
class DefaultConfig:
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
SQL_ECHO = env.bool('SQL_ECHO', False)
TRACER_PLUGIN_PATH = env.str('TRACER_PLUGIN_PATH', '')
TRACER_CLASS_NAME = env.str('TRACER_CLASS_NAME', '')
ROUTER_PLUGIN_PATH = env.str('ROUTER_PLUGIN_PATH', '')
ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter')
DISCOVERY_PLUGIN_PATH = env.str('DISCOVERY_PLUGIN_PATH', '')
DISCOVERY_CLASS_NAME = env.str('DISCOVERY_CLASS_NAME', 'static')
class TestingConfig(DefaultConfig):
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI', '')
SQL_ECHO = env.bool('SQL_TEST_ECHO', False)
TRACER_CLASS_NAME = env.str('TRACER_CLASS_TEST_NAME', '')
ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter')

@ -0,0 +1,101 @@
import logging
import pytest
import mock
from milvus import Milvus
from mishards.connections import (ConnectionMgr, Connection)
from mishards import exceptions
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestConnection:
def test_manager(self):
mgr = ConnectionMgr()
mgr.register('pod1', '111')
mgr.register('pod2', '222')
mgr.register('pod2', '222')
mgr.register('pod2', '2222')
assert len(mgr.conn_names) == 2
mgr.unregister('pod1')
assert len(mgr.conn_names) == 1
mgr.unregister('pod2')
assert len(mgr.conn_names) == 0
mgr.register('WOSERVER', 'xxxx')
assert len(mgr.conn_names) == 0
assert not mgr.conn('XXXX', None)
with pytest.raises(exceptions.ConnectionNotFoundError):
mgr.conn('XXXX', None, True)
mgr.conn('WOSERVER', None)
def test_connection(self):
class Conn:
def __init__(self, state):
self.state = state
def connect(self, uri):
return self.state
def connected(self):
return self.state
FAIL_CONN = Conn(False)
PASS_CONN = Conn(True)
class Retry:
def __init__(self):
self.times = 0
def __call__(self, conn):
self.times += 1
logger.info('Retrying {}'.format(self.times))
class Func():
def __init__(self):
self.executed = False
def __call__(self):
self.executed = True
max_retry = 3
RetryObj = Retry()
c = Connection('client',
uri='xx',
max_retry=max_retry,
on_retry_func=RetryObj)
c.conn = FAIL_CONN
ff = Func()
this_connect = c.connect(func=ff)
with pytest.raises(exceptions.ConnectionConnectError):
this_connect()
assert RetryObj.times == max_retry
assert not ff.executed
RetryObj = Retry()
c.conn = PASS_CONN
this_connect = c.connect(func=ff)
this_connect()
assert ff.executed
assert RetryObj.times == 0
this_connect = c.connect(func=None)
with pytest.raises(TypeError):
this_connect()
errors = []
def error_handler(err):
errors.append(err)
this_connect = c.connect(func=None, exception_handler=error_handler)
this_connect()
assert len(errors) == 1

@ -0,0 +1,39 @@
import logging
import pytest
from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory
from mishards import db, create_app, settings
from mishards.factories import (
Tables, TableFiles,
TablesFactory, TableFilesFactory
)
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestModels:
def test_files_to_search(self):
table = TablesFactory()
new_files_cnt = 5
to_index_cnt = 10
raw_cnt = 20
backup_cnt = 12
to_delete_cnt = 9
index_cnt = 8
new_index_cnt = 6
new_merge_cnt = 11
new_files = TableFilesFactory.create_batch(new_files_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW, date=110)
to_index_files = TableFilesFactory.create_batch(to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX, date=110)
raw_files = TableFilesFactory.create_batch(raw_cnt, table=table, file_type=TableFiles.FILE_TYPE_RAW, date=120)
backup_files = TableFilesFactory.create_batch(backup_cnt, table=table, file_type=TableFiles.FILE_TYPE_BACKUP, date=110)
index_files = TableFilesFactory.create_batch(index_cnt, table=table, file_type=TableFiles.FILE_TYPE_INDEX, date=110)
new_index_files = TableFilesFactory.create_batch(new_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_INDEX, date=110)
new_merge_files = TableFilesFactory.create_batch(new_merge_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_MERGE, date=110)
to_delete_files = TableFilesFactory.create_batch(to_delete_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_DELETE, date=110)
assert table.files_to_search().count() == raw_cnt + index_cnt + to_index_cnt
assert table.files_to_search([(100, 115)]).count() == index_cnt + to_index_cnt
assert table.files_to_search([(111, 120)]).count() == 0
assert table.files_to_search([(111, 121)]).count() == raw_cnt
assert table.files_to_search([(110, 121)]).count() == raw_cnt + index_cnt + to_index_cnt

@ -0,0 +1,279 @@
import logging
import pytest
import mock
import datetime
import random
import faker
import inspect
from milvus import Milvus
from milvus.client.types import Status, IndexType, MetricType
from milvus.client.abstract import IndexParam, TableSchema
from milvus.grpc_gen import status_pb2, milvus_pb2
from mishards import db, create_app, settings
from mishards.service_handler import ServiceHandler
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables
from mishards.router import RouterMixin
logger = logging.getLogger(__name__)
OK = Status(code=Status.SUCCESS, message='Success')
BAD = Status(code=Status.PERMISSION_DENIED, message='Fail')
@pytest.mark.usefixtures('started_app')
class TestServer:
@property
def client(self):
m = Milvus()
m.connect(host='localhost', port=settings.SERVER_TEST_PORT)
return m
def test_server_start(self, started_app):
assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER
def test_cmd(self, started_app):
ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK,
''))
status, _ = self.client.server_version()
assert status.OK()
Parser.parse_proto_Command = mock.MagicMock(return_value=(BAD, 'cmd'))
status, _ = self.client.server_version()
assert not status.OK()
def test_drop_index(self, started_app):
table_name = inspect.currentframe().f_code.co_name
ServiceHandler._drop_index = mock.MagicMock(return_value=OK)
status = self.client.drop_index(table_name)
assert status.OK()
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status = self.client.drop_index(table_name)
assert not status.OK()
def test_describe_index(self, started_app):
table_name = inspect.currentframe().f_code.co_name
index_type = IndexType.FLAT
nlist = 1
index_param = IndexParam(table_name=table_name,
index_type=index_type,
nlist=nlist)
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._describe_index = mock.MagicMock(
return_value=(OK, index_param))
status, ret = self.client.describe_index(table_name)
assert status.OK()
assert ret._table_name == index_param._table_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status, _ = self.client.describe_index(table_name)
assert not status.OK()
def test_preload(self, started_app):
table_name = inspect.currentframe().f_code.co_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._preload_table = mock.MagicMock(return_value=OK)
status = self.client.preload_table(table_name)
assert status.OK()
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status = self.client.preload_table(table_name)
assert not status.OK()
@pytest.mark.skip
def test_delete_by_range(self, started_app):
table_name = inspect.currentframe().f_code.co_name
unpacked = table_name, datetime.datetime.today(), datetime.datetime.today()
Parser.parse_proto_DeleteByRangeParam = mock.MagicMock(
return_value=(OK, unpacked))
ServiceHandler._delete_by_range = mock.MagicMock(return_value=OK)
status = self.client.delete_vectors_by_range(
*unpacked)
assert status.OK()
Parser.parse_proto_DeleteByRangeParam = mock.MagicMock(
return_value=(BAD, unpacked))
status = self.client.delete_vectors_by_range(
*unpacked)
assert not status.OK()
def test_count_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
count = random.randint(100, 200)
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._count_table = mock.MagicMock(return_value=(OK, count))
status, ret = self.client.get_table_row_count(table_name)
assert status.OK()
assert ret == count
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status, _ = self.client.get_table_row_count(table_name)
assert not status.OK()
def test_show_tables(self, started_app):
tables = ['t1', 't2']
ServiceHandler._show_tables = mock.MagicMock(return_value=(OK, tables))
status, ret = self.client.show_tables()
assert status.OK()
assert ret == tables
def test_describe_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
dimension = 128
nlist = 1
table_schema = TableSchema(table_name=table_name,
index_file_size=100,
metric_type=MetricType.L2,
dimension=dimension)
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_schema.table_name))
ServiceHandler._describe_table = mock.MagicMock(
return_value=(OK, table_schema))
status, _ = self.client.describe_table(table_name)
assert status.OK()
ServiceHandler._describe_table = mock.MagicMock(
return_value=(BAD, table_schema))
status, _ = self.client.describe_table(table_name)
assert not status.OK()
Parser.parse_proto_TableName = mock.MagicMock(return_value=(BAD,
'cmd'))
status, ret = self.client.describe_table(table_name)
assert not status.OK()
def test_insert(self, started_app):
table_name = inspect.currentframe().f_code.co_name
vectors = [[random.random() for _ in range(16)] for _ in range(10)]
ids = [random.randint(1000000, 20000000) for _ in range(10)]
ServiceHandler._add_vectors = mock.MagicMock(return_value=(OK, ids))
status, ret = self.client.add_vectors(
table_name=table_name, records=vectors)
assert status.OK()
assert ids == ret
def test_create_index(self, started_app):
table_name = inspect.currentframe().f_code.co_name
unpacks = table_name, None
Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(OK,
unpacks))
ServiceHandler._create_index = mock.MagicMock(return_value=OK)
status = self.client.create_index(table_name=table_name)
assert status.OK()
Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(BAD,
None))
status = self.client.create_index(table_name=table_name)
assert not status.OK()
def test_drop_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._delete_table = mock.MagicMock(return_value=OK)
status = self.client.delete_table(table_name=table_name)
assert status.OK()
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status = self.client.delete_table(table_name=table_name)
assert not status.OK()
def test_has_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._has_table = mock.MagicMock(return_value=(OK, True))
has = self.client.has_table(table_name=table_name)
assert has
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status, has = self.client.has_table(table_name=table_name)
assert not status.OK()
assert not has
def test_create_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
dimension = 128
table_schema = dict(table_name=table_name,
index_file_size=100,
metric_type=MetricType.L2,
dimension=dimension)
ServiceHandler._create_table = mock.MagicMock(return_value=OK)
status = self.client.create_table(table_schema)
assert status.OK()
Parser.parse_proto_TableSchema = mock.MagicMock(return_value=(BAD,
None))
status = self.client.create_table(table_schema)
assert not status.OK()
def random_data(self, n, dimension):
return [[random.random() for _ in range(dimension)] for _ in range(n)]
def test_search(self, started_app):
table_name = inspect.currentframe().f_code.co_name
to_index_cnt = random.randint(10, 20)
table = TablesFactory(table_id=table_name, state=Tables.NORMAL)
to_index_files = TableFilesFactory.create_batch(
to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX)
topk = random.randint(5, 10)
nq = random.randint(5, 10)
param = {
'table_name': table_name,
'query_records': self.random_data(nq, table.dimension),
'top_k': topk,
'nprobe': 2049
}
result = [
milvus_pb2.TopKQueryResult(query_result_arrays=[
milvus_pb2.QueryResult(id=i, distance=random.random())
for i in range(topk)
]) for i in range(nq)
]
mock_results = milvus_pb2.TopKQueryResultList(status=status_pb2.Status(
error_code=status_pb2.SUCCESS, reason="Success"),
topk_query_result=result)
table_schema = TableSchema(table_name=table_name,
index_file_size=table.index_file_size,
metric_type=table.metric_type,
dimension=table.dimension)
status, _ = self.client.search_vectors(**param)
assert status.code == Status.ILLEGAL_ARGUMENT
param['nprobe'] = 2048
RouterMixin.connection = mock.MagicMock(return_value=Milvus())
RouterMixin.query_conn = mock.MagicMock(return_value=Milvus())
Milvus.describe_table = mock.MagicMock(return_value=(BAD,
table_schema))
status, ret = self.client.search_vectors(**param)
assert status.code == Status.TABLE_NOT_EXISTS
Milvus.describe_table = mock.MagicMock(return_value=(OK, table_schema))
Milvus.search_vectors_in_files = mock.MagicMock(
return_value=mock_results)
status, ret = self.client.search_vectors(**param)
assert status.OK()
assert len(ret) == nq
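Every test in this file follows one pattern: stub a handler or parser method with `mock.MagicMock`, call through the client, and assert on the status the client sees. A minimal self-contained sketch of that pattern (the `Handler`/`Client` names here are illustrative stand-ins, not Milvus SDK classes):

```python
from unittest import mock

# Hypothetical stand-ins for ServiceHandler and the Milvus client.
class Handler:
    def version(self):
        return ('OK', '0.6.0')

class Client:
    def __init__(self, handler):
        self.handler = handler

    def server_version(self):
        return self.handler.version()

# Stub the handler method, then assert on what the client observes.
handler = Handler()
handler.version = mock.MagicMock(return_value=('BAD', ''))
status, _ = Client(handler).server_version()
assert status == 'BAD'
```

Note that the suite above assigns `MagicMock` objects to class attributes directly (e.g. `ServiceHandler._describe_index`), so a stub persists across later tests unless reassigned; `mock.patch` context managers would scope each stub to a single test.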


@ -0,0 +1,20 @@
import datetime
from mishards import exceptions
def format_date(start, end):
return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
(end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)
def range_to_date(range_obj, metadata=None):
try:
start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
assert start < end
except (ValueError, AssertionError):
raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
range_obj.start_date, range_obj.end_date),
metadata=metadata)
return format_date(start, end)
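`range_to_date` validates the range, then packs each endpoint into the integer layout used by `format_date`: `(year - 1900) * 10000 + (month - 1) * 100 + day`. A quick standalone check of that packing (copy of `format_date` above, sample dates chosen for illustration):

```python
import datetime

def format_date(start, end):
    # Pack yyyy-mm-dd into (year - 1900) * 10000 + (month - 1) * 100 + day
    return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
            (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)

start = datetime.datetime(2019, 11, 1)
end = datetime.datetime(2019, 11, 7)
print(format_date(start, end))  # → (1191001, 1191007)
```

So 2019-11-01 becomes 119 * 10000 + 10 * 100 + 1 = 1191001: the year offset and zero-based month make the value compact while keeping dates ordered numerically.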

37
shards/requirements.txt Normal file

@ -0,0 +1,37 @@
environs==4.2.0
factory-boy==2.12.0
Faker==1.0.7
fire==0.1.3
google-auth==1.6.3
grpcio==1.22.0
grpcio-tools==1.22.0
kubernetes==10.0.1
MarkupSafe==1.1.1
marshmallow==2.19.5
pymysql==0.9.3
protobuf==3.9.1
py==1.8.0
pyasn1==0.4.7
pyasn1-modules==0.2.6
pylint==2.3.1
pymilvus-test==0.2.28
#pymilvus==0.2.0
pyparsing==2.4.0
pytest==4.6.3
pytest-level==0.1.1
pytest-print==0.1.2
pytest-repeat==0.8.0
pytest-timeout==1.3.3
python-dateutil==2.8.0
python-dotenv==0.10.3
pytz==2019.1
requests==2.22.0
requests-oauthlib==1.2.0
rsa==4.0
six==1.12.0
SQLAlchemy==1.3.5
urllib3==1.25.3
jaeger-client>=3.4.0
grpcio-opentracing>=1.0
mock==2.0.0
pluginbase==1.0.0

4
shards/setup.cfg Normal file

@ -0,0 +1,4 @@
[tool:pytest]
testpaths = mishards
log_cli=true
log_cli_level=info

43
shards/tracer/__init__.py Normal file

@ -0,0 +1,43 @@
from contextlib import contextmanager
def empty_server_interceptor_decorator(target_server, interceptor):
return target_server
@contextmanager
def EmptySpan(*args, **kwargs):
yield None
return
class Tracer:
def __init__(self,
tracer=None,
interceptor=None,
server_decorator=empty_server_interceptor_decorator):
self.tracer = tracer
self.interceptor = interceptor
self.server_decorator = server_decorator
def decorate(self, server):
return self.server_decorator(server, self.interceptor)
@property
def empty(self):
return self.tracer is None
def close(self):
self.tracer and self.tracer.close()
def start_span(self,
operation_name=None,
child_of=None,
references=None,
tags=None,
start_time=None,
ignore_active_span=False):
if self.empty:
return EmptySpan()
return self.tracer.start_span(operation_name, child_of, references,
tags, start_time, ignore_active_span)
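`Tracer` is a null-object wrapper: when no tracer backend is configured, `start_span` hands back `EmptySpan`, so calling code can always use a `with` block whether tracing is enabled or not. The core of that trick in isolation (the `operation_name` argument is just illustrative):

```python
from contextlib import contextmanager

@contextmanager
def EmptySpan(*args, **kwargs):
    # Null-object span: accepts any arguments, yields nothing, does nothing.
    yield None

# Caller code is identical with or without a real tracer backend:
with EmptySpan(operation_name='search') as span:
    assert span is None  # no-op when tracing is disabled
```

This is why `ServiceHandler`-style callers never need an `if tracer:` branch around their span blocks.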

27
shards/tracer/factory.py Normal file

@ -0,0 +1,27 @@
import os
import logging
from tracer import Tracer
from utils.plugins import BaseMixin
logger = logging.getLogger(__name__)
PLUGIN_PACKAGE_NAME = 'tracer.plugins'
class TracerFactory(BaseMixin):
PLUGIN_TYPE = 'Tracer'
def __init__(self, searchpath=None):
super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME)
def create(self, class_name, **kwargs):
if not class_name:
return Tracer()
return super().create(class_name, **kwargs)
def _create(self, plugin_class, **kwargs):
plugin_config = kwargs.pop('plugin_config', None)
if not plugin_config:
raise RuntimeError('\'{}\' Plugin Config is Required!'.format(self.PLUGIN_TYPE))
plugin = plugin_class.Create(plugin_config=plugin_config, **kwargs)
return plugin


@ -0,0 +1,35 @@
import logging
from jaeger_client import Config
from grpc_opentracing.grpcext import intercept_server
from grpc_opentracing import open_tracing_server_interceptor
from tracer import Tracer
logger = logging.getLogger(__name__)
PLUGIN_NAME = __file__
class JaegerFactory:
name = 'jaeger'
@classmethod
def Create(cls, plugin_config, **kwargs):
tracing_config = plugin_config.TRACING_CONFIG
span_decorator = kwargs.pop('span_decorator', None)
service_name = plugin_config.TRACING_SERVICE_NAME
validate = plugin_config.TRACING_VALIDATE
config = Config(config=tracing_config,
service_name=service_name,
validate=validate)
tracer = config.initialize_tracer()
tracer_interceptor = open_tracing_server_interceptor(
tracer,
log_payloads=plugin_config.TRACING_LOG_PAYLOAD,
span_decorator=span_decorator)
return Tracer(tracer, tracer_interceptor, intercept_server)
def setup(app):
logger.info('Plugin \'{}\' Installed In Package: {}'.format(PLUGIN_NAME, app.plugin_package_name))
app.on_plugin_setup(JaegerFactory)

18
shards/utils/__init__.py Normal file

@ -0,0 +1,18 @@
from functools import wraps
def singleton(cls):
instances = {}
@wraps(cls)
def getinstance(*args, **kw):
if cls not in instances:
instances[cls] = cls(*args, **kw)
return instances[cls]
return getinstance
class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
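The two helpers above are small enough to demonstrate end to end (standalone copy, with an illustrative `Config` class):

```python
from functools import wraps

def singleton(cls):
    # Cache one instance per decorated class; later constructor args are ignored.
    instances = {}
    @wraps(cls)
    def getinstance(*args, **kw):
        if cls not in instances:
            instances[cls] = cls(*args, **kw)
        return instances[cls]
    return getinstance

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

@singleton
class Config:
    def __init__(self, name):
        self.name = name

assert Config('a') is Config('b')   # second call returns the cached instance
d = dotdict(host='localhost')
assert d.host == 'localhost'
assert d.missing is None            # dict.get returns None for absent keys
```

Two caveats worth noting: the singleton silently discards arguments after the first call (`Config('b')` still has `name == 'a'`), and `dotdict.__getattr__` returns `None` rather than raising `AttributeError` for missing keys.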


@ -0,0 +1,152 @@
import os
import datetime
from pytz import timezone
from logging import Filter
import logging.config
class InfoFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.INFO
class DebugFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.DEBUG
class WarnFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.WARN
class ErrorFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.ERROR
class CriticalFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.CRITICAL
COLORS = {
'HEADER': '\033[95m',
'INFO': '\033[92m',
'DEBUG': '\033[94m',
'WARNING': '\033[93m',
'ERROR': '\033[95m',
'CRITICAL': '\033[91m',
'ENDC': '\033[0m',
}
class ColorFulFormatColMixin:
def format_col(self, message_str, level_name):
if level_name in COLORS.keys():
message_str = COLORS.get(level_name) + message_str + COLORS.get(
'ENDC')
return message_str
class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin):
def format(self, record):
message_str = super(ColorfulFormatter, self).format(record)
return self.format_col(message_str, level_name=record.levelname)
def config(log_level, log_path, name, tz='UTC'):
def build_log_file(level, log_path, name, tz):
utc_now = datetime.datetime.utcnow()
utc_tz = timezone('UTC')
local_tz = timezone(tz)
tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz)
return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"),
level)
if not os.path.exists(log_path):
os.makedirs(log_path)
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'default': {
'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)',
},
'colorful_console': {
'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)',
'()': ColorfulFormatter,
},
},
'filters': {
'InfoFilter': {
'()': InfoFilter,
},
'DebugFilter': {
'()': DebugFilter,
},
'WarnFilter': {
'()': WarnFilter,
},
'ErrorFilter': {
'()': ErrorFilter,
},
'CriticalFilter': {
'()': CriticalFilter,
},
},
'handlers': {
'milvus_celery_console': {
'class': 'logging.StreamHandler',
'formatter': 'colorful_console',
},
'milvus_debug_file': {
'level': 'DEBUG',
'filters': ['DebugFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('debug', log_path, name, tz)
},
'milvus_info_file': {
'level': 'INFO',
'filters': ['InfoFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('info', log_path, name, tz)
},
'milvus_warn_file': {
'level': 'WARN',
'filters': ['WarnFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('warn', log_path, name, tz)
},
'milvus_error_file': {
'level': 'ERROR',
'filters': ['ErrorFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('error', log_path, name, tz)
},
'milvus_critical_file': {
'level': 'CRITICAL',
'filters': ['CriticalFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('critical', log_path, name, tz)
},
},
'loggers': {
'': {
'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file',
'milvus_error_file', 'milvus_critical_file'],
'level': log_level,
'propagate': False
},
},
'propagate': False,
}
logging.config.dictConfig(LOGGING)
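Each file handler above pairs a `level` threshold with a filter that passes exactly one level, so the warn log never receives ERROR records that a plain `level='WARN'` threshold alone would admit. The exact-level filtering in isolation (record fields here are illustrative):

```python
import logging

class WarnFilter(logging.Filter):
    def filter(self, rec):
        # Pass only WARNING records; other levels go to their own handlers.
        return rec.levelno == logging.WARNING

f = WarnFilter()
warn_rec = logging.LogRecord('x', logging.WARNING, __file__, 1, 'w', None, None)
err_rec = logging.LogRecord('x', logging.ERROR, __file__, 1, 'e', None, None)
print(f.filter(warn_rec), f.filter(err_rec))  # → True False
```

Without the filter, a handler at WARN level would also log ERROR and CRITICAL records, duplicating them across files.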


@ -0,0 +1,16 @@
import importlib.util
from pluginbase import PluginBase, PluginSource
class MiPluginSource(PluginSource):
def load_plugin(self, name):
plugin = super().load_plugin(name)
spec = importlib.util.spec_from_file_location(self.base.package + '.' + name, plugin.__file__)
plugin = importlib.util.module_from_spec(spec)
spec.loader.exec_module(plugin)
return plugin
class MiPluginBase(PluginBase):
def make_plugin_source(self, *args, **kwargs):
return MiPluginSource(self, *args, **kwargs)
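`MiPluginSource.load_plugin` re-executes the plugin file through `importlib.util`, so each factory gets a freshly executed module object rather than pluginbase's cached one. The spec-based loading step on its own (the temp-file plugin is illustrative):

```python
import importlib.util
import os
import tempfile

# Write a tiny plugin file, then load it as a fresh module from its path.
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'demo_plugin.py')
    with open(path, 'w') as f:
        f.write('NAME = "jaeger"\n')

    spec = importlib.util.spec_from_file_location('plugins.demo_plugin', path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    print(module.NAME)  # → jaeger
```

Because the module is built from a new spec each time, module-level `setup(app)` hooks run again for every factory that loads the plugin.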


@ -0,0 +1,40 @@
import os
import inspect
from functools import partial
from utils.pluginextension import MiPluginBase as PluginBase
class BaseMixin(object):
def __init__(self, package_name, searchpath=None):
self.plugin_package_name = package_name
caller_path = os.path.dirname(inspect.stack()[1][1])
get_path = partial(os.path.join, caller_path)
plugin_base = PluginBase(package=self.plugin_package_name,
searchpath=[get_path('./plugins')])
self.class_map = {}
searchpath = searchpath if searchpath else []
searchpath = [searchpath] if isinstance(searchpath, str) else searchpath
self.source = plugin_base.make_plugin_source(searchpath=searchpath,
identifier=self.__class__.__name__)
for plugin_name in self.source.list_plugins():
plugin = self.source.load_plugin(plugin_name)
plugin.setup(self)
def on_plugin_setup(self, plugin_class):
name = getattr(plugin_class, 'name', plugin_class.__name__)
self.class_map[name.lower()] = plugin_class
def plugin(self, name):
return self.class_map.get(name, None)
def create(self, class_name, **kwargs):
if not class_name:
raise RuntimeError('Please specify \'{}\' class_name first!'.format(self.PLUGIN_TYPE))
plugin_class = self.plugin(class_name.lower())
if not plugin_class:
raise RuntimeError('{} Plugin \'{}\' Not Installed!'.format(self.PLUGIN_TYPE, class_name))
return self._create(plugin_class, **kwargs)
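`BaseMixin` keeps a name-to-class registry: each plugin's module-level `setup(app)` calls back into `on_plugin_setup`, and `create` resolves the class case-insensitively before delegating to the subclass's `_create`. The registry flow without the pluginbase machinery (the `Registry`/`JaegerFactory` names are stand-ins for illustration):

```python
class Registry:
    PLUGIN_TYPE = 'Tracer'

    def __init__(self):
        self.class_map = {}

    def on_plugin_setup(self, plugin_class):
        # Register under the plugin's declared name, lowercased.
        name = getattr(plugin_class, 'name', plugin_class.__name__)
        self.class_map[name.lower()] = plugin_class

    def create(self, class_name, **kwargs):
        plugin_class = self.class_map.get(class_name.lower())
        if not plugin_class:
            raise RuntimeError('{} Plugin {!r} Not Installed!'.format(
                self.PLUGIN_TYPE, class_name))
        return plugin_class.Create(**kwargs)

class JaegerFactory:
    name = 'Jaeger'

    @classmethod
    def Create(cls, **kwargs):
        return 'tracer-instance'

registry = Registry()
registry.on_plugin_setup(JaegerFactory)   # what each plugin's setup(app) does
assert registry.create('jaeger') == 'tracer-instance'
```

Lowercasing at both registration and lookup is what lets `TracerFactory.create('Jaeger')` and `create('jaeger')` resolve to the same plugin class.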