"""
Main application for RGB detection demo.

Any new model should implement the following functions:
    - load_model(model_path, img_size=640)
    - inference(model, image)
"""

import os
import glob
import spaces
import gradio as gr
from huggingface_hub import get_token
from utils import (
    check_image,
    load_image_from_url,
    load_badges,
    FlaggedCounter,
)
from flagging import HuggingFaceDatasetSaver

# Keep this import ahead of `seavision`.
# NOTE(review): presumably importing this module installs the private
# packages (seavision among them) as a side effect - confirm before reordering.
import install_private_repos  # noqa: F401
from seavision import load_model

# Rendered through gr.HTML at the top of the page.
# NOTE(review): the `css` rule below targets <h1>, but no markup is visible in
# this string - confirm whether tags were lost from TITLE.
TITLE = """

🌊 SEA.AI's Vision Demo ✨

Ahoy! Explore our object detection technology! Upload a maritime scene image and click Submit to see the results.

"""

FLAG_TXT = "Report Mis-detection"

# Markdown hint shown next to the flag button. Split across adjacent literals
# for readability; the resulting text is unchanged.
NOTICE = (
    " 🚩 See something off? Your feedback makes a difference! "
    "Let us know by flagging any outcomes that don't seem right. "
    f"Click the `{FLAG_TXT}` button to submit the image for review. "
)

css = """ h1 { text-align: center; display: block; } """

model = load_model("ahoy-RGB-b2")


# NOTE(review): spaces.GPU presumably requests GPU hardware for the duration
# of the call when running on Hugging Face Spaces - confirm against its docs.
@spaces.GPU
def inference(image):
    """Run the detection model on ``image`` and return the annotated image.

    Args:
        image: Input image, as delivered by the ``gr.Image`` input component.

    Returns:
        Whatever ``results.draw(image, diameter=4)`` produces for the
        predictions of ``model(image)``.
    """
    results = model(image)
    return results.draw(image, diameter=4)


def _is_bundled_example(file_data):
    """Tell whether a raw image payload refers to a file in ``examples/``.

    Args:
        file_data: Un-preprocessed payload of the input image (a mapping that
            may carry an ``"orig_name"`` entry), or ``None`` when no image is
            set.

    Returns:
        ``True`` only if the payload's ``orig_name`` matches an entry of the
        ``examples`` directory. A missing payload or a missing directory
        yields ``False``.
    """
    if not file_data:
        return False
    try:
        bundled = os.listdir("examples")
    except FileNotFoundError:
        # No examples directory means nothing can be a bundled example.
        return False
    return file_data.get("orig_name") in bundled


# Flagging
dataset_name = "SEA-AI/crowdsourced-sea-images"
hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
flagged_counter = FlaggedCounter(dataset_name)

theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)

# NOTE(review): the nesting below was reconstructed from a whitespace-collapsed
# source - confirm the layout against the deployed app.
with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
    badges = gr.HTML(load_badges(flagged_counter.count()))
    title = gr.HTML(TITLE)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                label="input", interactive=True, sources=["upload", "clipboard"]
            )
            img_url = gr.Textbox(
                lines=1,
                placeholder="or enter URL to image here",
                label="input_url",
                show_label=False,
            )
            with gr.Row():
                clear = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="output", interactive=False)
            flag = gr.Button(FLAG_TXT, visible=False)
            notice = gr.Markdown(value=NOTICE, visible=False)

    examples = gr.Examples(
        examples=glob.glob("examples/*.jpg"),
        inputs=img_input,
        outputs=img_output,
        fn=inference,
        cache_examples=True,
    )

    # add components to clear when clear button is clicked
    clear.add([img_input, img_url, img_output])

    # event listeners
    img_url.change(load_image_from_url, [img_url], img_input)
    # inference only runs if check_image completes without raising
    submit.click(check_image, [img_input], None, show_api=False).success(
        inference,
        [img_input],
        img_output,
        api_name="inference",
    )

    # event listeners with decorators
    @img_output.change(
        inputs=[img_input, img_output],
        outputs=[flag, notice],
        show_api=False,
        preprocess=False,
        show_progress="hidden",
    )
    def _show_hide_flagging(_img_input, _img_output):
        """Toggle the flag button and notice whenever the output image changes.

        Both arguments arrive un-preprocessed (``preprocess=False``). The
        flagging controls are shown only when there is an output, there is an
        input to report, and that input is not one of the bundled examples.

        Returns:
            A mapping from the ``flag`` and ``notice`` components to their
            updated versions.
        """
        # Force a real bool: `x and y` would otherwise leak None or the raw
        # payload into `visible`, and an empty input must not be subscripted.
        visible = (
            bool(_img_output)
            and bool(_img_input)
            and not _is_bundled_example(_img_input)
        )
        return {
            flag: gr.Button(FLAG_TXT, interactive=True, visible=visible),
            notice: gr.Markdown(value=NOTICE, visible=visible),
        }

    # This needs to be called prior to the first call to callback.flag()
    hf_writer.setup([img_input], "flagged")

    # Sequential logic when flag button is clicked:
    # thank the user -> disable the button -> store the sample -> refresh badges
    flag.click(lambda: gr.Info("Thank you for contributing!"), show_api=False).then(
        lambda: {flag: gr.Button(FLAG_TXT, interactive=False)},
        [],
        [flag],
        show_api=False,
    ).then(
        lambda *args: hf_writer.flag(args),
        [img_input, flag],
        [],
        preprocess=False,
        show_api=False,
    ).then(
        lambda: load_badges(flagged_counter.count()), [], badges, show_api=False
    )

    # called during initial load in browser
    demo.load(lambda: load_badges(flagged_counter.count()), [], badges, show_api=False)


if __name__ == "__main__":
    demo.queue().launch()