Adds ROI
parent
c2f2334766
commit
6f2ced8b9c
|
|
@ -18,6 +18,17 @@ DEEPSTACK_TIMEOUT = int(os.getenv("DEEPSTACK_TIMEOUT", 10))
|
|||
DEFAULT_CONFIDENCE_THRESHOLD = 0
|
||||
TEST_IMAGE = "street.jpg"
|
||||
|
||||
DEFAULT_ROI_Y_MIN = 0.0
|
||||
DEFAULT_ROI_Y_MAX = 1.0
|
||||
DEFAULT_ROI_X_MIN = 0.0
|
||||
DEFAULT_ROI_X_MAX = 1.0
|
||||
DEFAULT_ROI = (
|
||||
DEFAULT_ROI_Y_MIN,
|
||||
DEFAULT_ROI_X_MIN,
|
||||
DEFAULT_ROI_Y_MAX,
|
||||
DEFAULT_ROI_X_MAX,
|
||||
)
|
||||
|
||||
predictions = None
|
||||
|
||||
|
||||
|
|
@ -29,6 +40,7 @@ def process_image(pil_image, dsobject):
|
|||
return predictions
|
||||
|
||||
|
||||
## Setup sidebar
|
||||
st.title("Deepstack Object detection")
|
||||
img_file_buffer = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
|
||||
|
||||
|
|
@ -36,12 +48,33 @@ st.sidebar.title("Parameters")
|
|||
CONFIDENCE_THRESHOLD = st.sidebar.slider(
|
||||
"Confidence threshold", 0, 100, DEFAULT_CONFIDENCE_THRESHOLD, 1
|
||||
)
|
||||
|
||||
CLASSES_TO_INCLUDE = st.sidebar.multiselect(
|
||||
"Select object classes to include",
|
||||
options=const.CLASSES,
|
||||
default=const.DEFAULT_CLASSES,
|
||||
)
|
||||
|
||||
# Get ROI info
|
||||
st.sidebar.title("ROI")
|
||||
ROI_X_MIN = st.sidebar.slider("x_min", 0.0, 1.0, DEFAULT_ROI_X_MIN)
|
||||
ROI_Y_MIN = st.sidebar.slider("y_min", 0.0, 1.0, DEFAULT_ROI_Y_MIN)
|
||||
ROI_X_MAX = st.sidebar.slider("x_max", 0.0, 1.0, DEFAULT_ROI_X_MAX)
|
||||
ROI_Y_MAX = st.sidebar.slider("y_max", 0.0, 1.0, DEFAULT_ROI_Y_MAX)
|
||||
ROI_TUPLE = (
|
||||
ROI_Y_MIN,
|
||||
ROI_X_MIN,
|
||||
ROI_Y_MAX,
|
||||
ROI_X_MAX,
|
||||
)
|
||||
ROI_DICT = {
|
||||
"x_min": ROI_X_MIN,
|
||||
"y_min": ROI_Y_MIN,
|
||||
"x_max": ROI_X_MAX,
|
||||
"y_max": ROI_Y_MAX,
|
||||
}
|
||||
|
||||
## Process image
|
||||
if img_file_buffer is not None:
|
||||
pil_image = Image.open(img_file_buffer)
|
||||
|
||||
|
|
@ -59,7 +92,9 @@ all_objects_names = set([obj["name"] for obj in objects])
|
|||
# Filter objects for display
|
||||
objects = [obj for obj in objects if obj["confidence"] > CONFIDENCE_THRESHOLD]
|
||||
objects = [obj for obj in objects if obj["name"] in CLASSES_TO_INCLUDE]
|
||||
objects = [obj for obj in objects if utils.object_in_roi(ROI_DICT, obj["centroid"])]
|
||||
|
||||
# Draw object boxes
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
for obj in objects:
|
||||
name = obj["name"]
|
||||
|
|
@ -76,13 +111,25 @@ for obj in objects:
|
|||
color=const.YELLOW,
|
||||
)
|
||||
|
||||
# Draw ROI box
|
||||
if ROI_TUPLE != DEFAULT_ROI:
|
||||
utils.draw_box(
|
||||
draw,
|
||||
ROI_TUPLE,
|
||||
pil_image.width,
|
||||
pil_image.height,
|
||||
text="ROI",
|
||||
color=const.GREEN,
|
||||
)
|
||||
|
||||
# Display image and results
|
||||
st.image(
|
||||
np.array(pil_image), caption=f"Processed image", use_column_width=True,
|
||||
)
|
||||
st.subheader("All discovered objects")
|
||||
st.write(all_objects_names)
|
||||
|
||||
st.subheader("Object count")
|
||||
st.subheader("Filtered object count")
|
||||
obj_types = list(set([obj["name"] for obj in objects]))
|
||||
for obj_type in obj_types:
|
||||
obj_type_count = len([obj for obj in objects if obj["name"] == obj_type])
|
||||
|
|
|
|||
Loading…
Reference in New Issue