File size: 3,561 Bytes
edebe10
 
 
 
 
f53de7a
edebe10
f53de7a
edebe10
 
 
f53de7a
edebe10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Imports standard
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
import os

# Imports Hugging Face
from huggingface_hub import hf_hub_download, login
from google.colab import userdata

# Imports locaux
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats
from inference_utils.processing_utils import read_rgb, get_instances

def init_huggingface():
    """Initialise la connexion Hugging Face et télécharge le modèle."""
    login(userdata.get('HF_TOKEN'))
    return hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        local_dir="pretrained"
    )

def setup_model():
    """Configure et retourne le modèle."""

    opt = init_distributed(opt)
    model = BaseModel(opt, build_model(opt)).from_pretrained('hf_hub:microsoft/BiomedParse').eval().cuda()
    
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"],
            is_eval=True
        )
    return model

def process_image(image, prompts, model):
    """Traite l'image avec les prompts donnés."""
    if isinstance(image, str):
        image = Image.open(image)
    else:
        image = Image.fromarray(image)
    
    prompts = [p.strip() for p in prompts.split(',')]
    
    pred_masks = interactive_infer_image(model, image, prompts)
    
    fig = plt.figure(figsize=(10, 5))
    plt.subplot(1, len(pred_masks) + 1, 1)
    plt.imshow(image)
    plt.title('Image originale')
    plt.axis('off')

    for i, mask in enumerate(pred_masks):
        plt.subplot(1, len(pred_masks) + 1, i+2)
        plt.imshow(image)
        plt.imshow(mask, alpha=0.5, cmap='Reds')
        plt.title(prompts[i])
        plt.axis('off')
    
    return fig

def setup_gradio_interface(model):
    """Configure l'interface Gradio."""
    return gr.Interface(
        theme=gr.Theme.from_hub("allenai/gradio-theme"),
        fn=lambda img, txt: process_image(img, txt, model),
        inputs=[
            gr.Image(type="numpy", label="Image médicale"),
            gr.Textbox(
                label="Prompts (séparés par des virgules)",
                placeholder="edema, lesion, etc...",
                elem_classes="white"
            )
        ],
        outputs=gr.Plot(),
        title="Core IA - Traitement d'image medicale",
        description="Chargez une image médicale et spécifiez les éléments à segmenter",
        examples=[
            ["examples/144DME_as_F.jpeg", "Dans cette image donne moi l'œdème"],
            ["examples/ISIC_0015551.jpg", "Cherche une lésion"],
            ["examples/T0011.jpg", "disque optique, cupule optique"],
            ["examples/C3_EndoCV2021_00462.jpg", "Trouve moi le polyp"],
            ["examples/covid_1585.png", "Qu'est ce qui ne va pas ici ?"],
            ['examples/Part_1_516_pathology_breast.png', "cellules néoplasiques , cellules inflammatoires ,  cellules du tissu conjonctif"]
        ]
    )

def main():
    """Point d'entrée principal de l'application."""
    init_huggingface()
    model = setup_model()
    interface = setup_gradio_interface(model)
    interface.launch(debug=True)

if __name__ == "__main__":
    main()