Spaces:
Build error
Build error
# Imports standard | |
import torch | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
import os | |
# Imports Hugging Face | |
from huggingface_hub import hf_hub_download, login | |
from google.colab import userdata | |
# Imports locaux | |
from modeling.BaseModel import BaseModel | |
from modeling import build_model | |
from utilities.distributed import init_distributed | |
from utilities.arguments import load_opt_from_config_files | |
from utilities.constants import BIOMED_CLASSES | |
from inference_utils.inference import interactive_infer_image | |
from inference_utils.output_processing import check_mask_stats | |
from inference_utils.processing_utils import read_rgb, get_instances | |
def init_huggingface():
    """Authenticate with the Hugging Face Hub and download the BiomedParse checkpoint.

    The access token is taken from the ``HF_TOKEN`` environment variable.
    Only when that variable is unset or empty is Google Colab's ``userdata``
    secret store consulted, and only if ``google.colab`` can be imported.
    When no token is found anywhere, ``login`` is skipped instead of being
    called with ``None``.

    NOTE(review): the module-level ``from google.colab import userdata``
    import still makes the whole file Colab-only. This function no longer
    needs it, so consider removing it.

    Returns:
        The local path of ``biomedparse_v1.pt``, downloaded into ``pretrained/``
        (the value returned by ``hf_hub_download``).
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Colab-only fallback; importing lazily keeps this function usable
        # on hosts where google.colab is not installed.
        try:
            from google.colab import userdata as colab_userdata
        except ImportError:
            colab_userdata = None
        if colab_userdata is not None:
            token = colab_userdata.get('HF_TOKEN')
    if token:
        login(token)
    return hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        local_dir="pretrained",
    )
def setup_model(config_files=None, pretrained_pth='hf_hub:microsoft/BiomedParse'):
    """Build the BiomedParse model, load its weights and prepare it for inference.

    Args:
        config_files: List of YAML config paths given to
            ``load_opt_from_config_files``. Defaults to
            ``["configs/biomedparse_inference.yaml"]``.
        pretrained_pth: Checkpoint reference passed to ``from_pretrained``.
            The default is the Hub reference used by the original code.

    Returns:
        The model in eval mode on the CUDA device. ``get_text_embeddings``
        has already been called on ``BIOMED_CLASSES + ["background"]``.
    """
    if config_files is None:
        # NOTE(review): this path follows the upstream BiomedParse inference
        # example; confirm the file exists in this repository.
        config_files = ["configs/biomedparse_inference.yaml"]
    # `opt` must be loaded from the config before init_distributed can use it
    # (previously it was referenced before assignment -> UnboundLocalError).
    opt = load_opt_from_config_files(config_files)
    opt = init_distributed(opt)
    model = (
        BaseModel(opt, build_model(opt))
        .from_pretrained(pretrained_pth)
        .eval()
        .cuda()
    )
    # The return value is discarded, so this call presumably caches the class
    # embeddings inside the language encoder for later prompts. TODO confirm.
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"],
            is_eval=True,
        )
    return model
def process_image(image, prompts, model):
    """Segment the objects named in *prompts* and plot each mask over *image*.

    Args:
        image: Either a path to an image file, or an array such as the one
            produced by ``gr.Image(type="numpy")``.
        prompts: Comma-separated text prompts, e.g. ``"edema, lesion"``.
            Surrounding whitespace is stripped and empty entries are ignored.
        model: The model passed to ``interactive_infer_image``.

    Returns:
        A matplotlib Figure. The first panel shows the original image. Each
        following panel shows the image with one predicted mask overlaid
        (``Reds`` colormap, alpha 0.5), titled with its prompt.

    Raises:
        gr.Error: If no image is given or no non-empty prompt remains.
    """
    if image is None:
        raise gr.Error("Veuillez charger une image.")
    if isinstance(image, str):
        image = Image.open(image)
    else:
        image = Image.fromarray(image)
    # Normalise to 3-channel RGB; image files may be RGBA, palette or grayscale.
    image = image.convert("RGB")

    # Drop blank entries so inputs like "edema," or "" do not create empty prompts.
    prompt_list = [p.strip() for p in (prompts or "").split(",") if p.strip()]
    if not prompt_list:
        raise gr.Error("Veuillez saisir au moins un prompt.")

    pred_masks = interactive_infer_image(model, image, prompt_list)

    # Draw on explicit axes rather than pyplot's global "current figure".
    fig, axes = plt.subplots(1, len(pred_masks) + 1, figsize=(10, 5), squeeze=False)
    axes = axes[0]

    axes[0].imshow(image)
    axes[0].set_title('Image originale')
    axes[0].axis('off')

    for ax, prompt, mask in zip(axes[1:], prompt_list, pred_masks):
        ax.imshow(image)
        ax.imshow(mask, alpha=0.5, cmap='Reds')
        ax.set_title(prompt)
        ax.axis('off')

    # Remove the figure from pyplot's registry so repeated requests do not
    # accumulate open figures. The returned Figure object can still be rendered.
    plt.close(fig)
    return fig
def setup_gradio_interface(model):
    """Create (but do not launch) the Gradio interface for medical image segmentation.

    Args:
        model: Model forwarded to ``process_image`` on every request.

    Returns:
        A ``gr.Interface`` that takes an image and a prompt text box as inputs
        and outputs a plot.
    """
    theme = gr.Theme.from_hub("allenai/gradio-theme")

    def _segment(img, txt):
        # Gradio passes the two input components' values positionally.
        return process_image(img, txt, model)

    image_input = gr.Image(type="numpy", label="Image médicale")
    prompt_input = gr.Textbox(
        label="Prompts (séparés par des virgules)",
        placeholder="edema, lesion, etc...",
        elem_classes="white",
    )
    plot_output = gr.Plot()

    # NOTE(review): several example prompts are full French sentences, while
    # process_image only splits on commas and forwards the text unchanged.
    # The placeholder suggests short English object names; confirm that the
    # model handles these examples.
    example_rows = [
        ["examples/144DME_as_F.jpeg", "Dans cette image donne moi l'œdème"],
        ["examples/ISIC_0015551.jpg", "Cherche une lésion"],
        ["examples/T0011.jpg", "disque optique, cupule optique"],
        ["examples/C3_EndoCV2021_00462.jpg", "Trouve moi le polyp"],
        ["examples/covid_1585.png", "Qu'est ce qui ne va pas ici ?"],
        ['examples/Part_1_516_pathology_breast.png', "cellules néoplasiques , cellules inflammatoires , cellules du tissu conjonctif"],
    ]

    return gr.Interface(
        fn=_segment,
        inputs=[image_input, prompt_input],
        outputs=plot_output,
        examples=example_rows,
        title="Core IA - Traitement d'image medicale",
        description="Chargez une image médicale et spécifiez les éléments à segmenter",
        theme=theme,
    )
def main():
    """Start the app: authenticate and download, build the model, serve the UI."""
    # The path returned by init_huggingface is discarded; setup_model is
    # called without it.
    init_huggingface()
    app = setup_gradio_interface(setup_model())
    app.launch(debug=True)


if __name__ == "__main__":
    # Launch only when executed as a script, not when imported.
    main()