diff --git a/app/streamlit-ui.py b/app/streamlit-ui.py index 53a5c1e..f4eadf7 100644 --- a/app/streamlit-ui.py +++ b/app/streamlit-ui.py @@ -18,6 +18,17 @@ DEEPSTACK_TIMEOUT = int(os.getenv("DEEPSTACK_TIMEOUT", 10)) DEFAULT_CONFIDENCE_THRESHOLD = 0 TEST_IMAGE = "street.jpg" +DEFAULT_ROI_Y_MIN = 0.0 +DEFAULT_ROI_Y_MAX = 1.0 +DEFAULT_ROI_X_MIN = 0.0 +DEFAULT_ROI_X_MAX = 1.0 +DEFAULT_ROI = ( + DEFAULT_ROI_Y_MIN, + DEFAULT_ROI_X_MIN, + DEFAULT_ROI_Y_MAX, + DEFAULT_ROI_X_MAX, +) + predictions = None @@ -29,6 +40,7 @@ def process_image(pil_image, dsobject): return predictions +## Setup sidebar st.title("Deepstack Object detection") img_file_buffer = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) @@ -36,12 +48,33 @@ st.sidebar.title("Parameters") CONFIDENCE_THRESHOLD = st.sidebar.slider( "Confidence threshold", 0, 100, DEFAULT_CONFIDENCE_THRESHOLD, 1 ) + CLASSES_TO_INCLUDE = st.sidebar.multiselect( "Select object classes to include", options=const.CLASSES, default=const.DEFAULT_CLASSES, ) +# Get ROI info +st.sidebar.title("ROI") +ROI_X_MIN = st.sidebar.slider("x_min", 0.0, 1.0, DEFAULT_ROI_X_MIN) +ROI_Y_MIN = st.sidebar.slider("y_min", 0.0, 1.0, DEFAULT_ROI_Y_MIN) +ROI_X_MAX = st.sidebar.slider("x_max", 0.0, 1.0, DEFAULT_ROI_X_MAX) +ROI_Y_MAX = st.sidebar.slider("y_max", 0.0, 1.0, DEFAULT_ROI_Y_MAX) +ROI_TUPLE = ( + ROI_Y_MIN, + ROI_X_MIN, + ROI_Y_MAX, + ROI_X_MAX, +) +ROI_DICT = { + "x_min": ROI_X_MIN, + "y_min": ROI_Y_MIN, + "x_max": ROI_X_MAX, + "y_max": ROI_Y_MAX, +} + +## Process image if img_file_buffer is not None: pil_image = Image.open(img_file_buffer) @@ -59,7 +92,9 @@ all_objects_names = set([obj["name"] for obj in objects]) # Filter objects for display objects = [obj for obj in objects if obj["confidence"] > CONFIDENCE_THRESHOLD] objects = [obj for obj in objects if obj["name"] in CLASSES_TO_INCLUDE] +objects = [obj for obj in objects if utils.object_in_roi(ROI_DICT, obj["centroid"])] +# Draw object boxes draw = ImageDraw.Draw(pil_image) for obj in objects: name = obj["name"] @@ -76,13 +111,25 @@ for obj in objects: color=const.YELLOW, ) +# Draw ROI box +if ROI_TUPLE != DEFAULT_ROI: + utils.draw_box( + draw, + ROI_TUPLE, + pil_image.width, + pil_image.height, + text="ROI", + color=const.GREEN, + ) + +# Display image and results st.image( np.array(pil_image), caption=f"Processed image", use_column_width=True, ) st.subheader("All discovered objects") st.write(all_objects_names) -st.subheader("Object count") +st.subheader("Filtered object count") obj_types = list(set([obj["name"] for obj in objects])) for obj_type in obj_types: obj_type_count = len([obj for obj in objects if obj["name"] == obj_type]) diff --git a/usage.png b/usage.png index 5d91a2c..4618c9f 100644 Binary files a/usage.png and b/usage.png differ