# Core-AI-IMAGE / app.py
# Standard imports
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
import os
# Hugging Face imports
from huggingface_hub import hf_hub_download, login
# Local imports
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats
from inference_utils.processing_utils import read_rgb, get_instances
def init_huggingface():
    """Initialize the Hugging Face connection and download the model checkpoint."""
    # The token is read from the environment (e.g. a Space secret named HF_TOKEN)
    # rather than from google.colab's userdata, so the app also runs outside Colab.
    login(os.environ.get("HF_TOKEN"))
    return hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        local_dir="pretrained"
    )
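
# Note: the checkpoint downloaded above ends up at pretrained/biomedparse_v1.pt.
# setup_model() below loads weights via 'hf_hub:microsoft/BiomedParse' instead,
# so the returned local path is currently unused.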
def setup_model():
    """Configure and return the model."""
    # Assumption: the BiomedParse repository ships an inference config at
    # configs/biomedparse_inference.yaml; adjust the path if your checkout differs.
    opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
    opt = init_distributed(opt)
    model = BaseModel(opt, build_model(opt)).from_pretrained('hf_hub:microsoft/BiomedParse').eval().cuda()
    # Pre-compute text embeddings for the biomedical class vocabulary plus "background".
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"],
            is_eval=True
        )
    return model
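
# Note: setup_model() moves the model to the GPU via .cuda(); a CUDA-capable
# runtime is assumed, and the call will fail on a CPU-only machine.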
def process_image(image, prompts, model):
    """Process the image with the given prompts and return a matplotlib figure."""
    if isinstance(image, str):
        image = Image.open(image)
    else:
        image = Image.fromarray(image)
    prompts = [p.strip() for p in prompts.split(',')]
    pred_masks = interactive_infer_image(model, image, prompts)
    # One panel for the original image, plus one overlay panel per prompt.
    fig = plt.figure(figsize=(10, 5))
    plt.subplot(1, len(pred_masks) + 1, 1)
    plt.imshow(image)
    plt.title('Original image')
    plt.axis('off')
    for i, mask in enumerate(pred_masks):
        plt.subplot(1, len(pred_masks) + 1, i + 2)
        plt.imshow(image)
        plt.imshow(mask, alpha=0.5, cmap='Reds')
        plt.title(prompts[i])
        plt.axis('off')
    return fig
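
# Usage sketch (not exercised by the app): process_image also accepts a file path, e.g.
#   fig = process_image("examples/T0011.jpg", "optic disc, optic cup", model)
#   fig.savefig("overlay.png")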
def setup_gradio_interface(model):
    """Configure the Gradio interface."""
    return gr.Interface(
        theme=gr.Theme.from_hub("allenai/gradio-theme"),
        fn=lambda img, txt: process_image(img, txt, model),
        inputs=[
            gr.Image(type="numpy", label="Medical image"),
            gr.Textbox(
                label="Prompts (comma-separated)",
                placeholder="edema, lesion, etc...",
                elem_classes="white"
            )
        ],
        outputs=gr.Plot(),
        title="Core AI - Medical Image Processing",
        description="Upload a medical image and specify the structures to segment",
        examples=[
            ["examples/144DME_as_F.jpeg", "In this image, find the edema"],
            ["examples/ISIC_0015551.jpg", "Look for a lesion"],
            ["examples/T0011.jpg", "optic disc, optic cup"],
            ["examples/C3_EndoCV2021_00462.jpg", "Find the polyp"],
            ["examples/covid_1585.png", "What is wrong here?"],
            ["examples/Part_1_516_pathology_breast.png", "neoplastic cells, inflammatory cells, connective tissue cells"]
        ]
    )
def main():
    """Main entry point of the application."""
    init_huggingface()
    model = setup_model()
    interface = setup_gradio_interface(model)
    interface.launch(debug=True)


if __name__ == "__main__":
    main()