File size: 5,088 Bytes
edcffc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58a44fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Identifier handed to the bundle loader, and the local directory (inside the
# torch hub cache) from which the bundle's inference config is read later on.
# NOTE(review): the identifier says "spleen_ct_segmentation" while the model
# repo loaded further down is a BraTS MRI one — confirm this is intentional.
BUNDLE_NAME = "spleen_ct_segmentation_v0.1.0"
BUNDLE_PATH = os.path.join(
    torch.hub.get_dir(),
    "bundle",
    BUNDLE_NAME,
)

# Static text for the demo page: an HTML heading, Markdown usage notes with a
# disclaimer, and a Markdown reference list.
title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠 </h1>'

description = """
## 🚀 To run

Upload a brain MRI image file, or try out one of the examples below! 
If you want to see a different slice, update the slider.

More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)

## ⚠️ Disclaimer

This is an example, not to be used for diagnostic purposes.
"""

references = """
## 👀 References

1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""

# Each example row pairs a bundled volume path with the slice index to show.
examples = [
    ["examples/BRATS_485.nii.gz", 65],
    ["examples/BRATS_486.nii.gz", 80],
]

# Fetch the pretrained bundle from the Hugging Face Hub as a TorchScript
# module. The call yields three values here; only the network is kept.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source="huggingface_hub",
    repo="katielink/brats_mri_segmentation_v0.1.0",
    load_ts_module=True,
)

# Run on the first CUDA device when one is present, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Config parser built from the bundle's inference config on disk; it must be
# created after the bundle has been loaded above.
parser = bundle.load_bundle_config(BUNDLE_PATH, "inference.json")

# Preprocessing pipeline applied to the {"image": ...} dictionary:
# load from disk, move channels first, reorient to RAS, then normalize the
# intensities (nonzero=True, channel_wise=True).
_preproc_steps = [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
preproc_transforms = Compose(_preproc_steps)

# Instantiate the inferer described under the "inferer" key of the bundle's
# inference config.
inferer = parser.get_parsed_content(
    "inferer",
    lazy=True,
    eval_expr=True,
    instantiate=True,
)

# Postprocessing pipeline: sigmoid on the prediction, binarize it at 0.5, and
# rescale the input image to the [0, 1] range for display.
_post_steps = [
    Activationsd(keys="pred", sigmoid=True),
    AsDiscreted(keys="pred", threshold=0.5),
    ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
]
post_transforms = Compose(_post_steps)


# Define the predict function for the demo
def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded volume and return one 2-D slice per channel.

    Args:
        input_file: Uploaded file object; only its ``name`` attribute (the
            path handed to the loading transform) is read.
        z_axis: Requested slice index along the last image axis. The value is
            rounded to an integer and clamped to the available slices, so a
            slider position beyond the volume's depth shows the last slice
            instead of raising ``IndexError``.
        model: Network passed to the inferer (defaults to the module-level
            model).
        device: Torch device on which inference runs (defaults to the
            module-level device).

    Returns:
        A pair of lists: the four image-channel slices
        ``[t1c, t1, t2, flair]`` and the three predicted label slices
        ``[tc, wt, et]``, all as numpy arrays.

    Raises:
        ValueError: If no input file was provided.
    """
    # Fail with a readable message instead of AttributeError on None.name
    if input_file is None:
        raise ValueError('No input file provided; please upload an image file first.')

    # Load and process data in MONAI format
    data = preproc_transforms({'image': [input_file.name]})

    # Run inference and post-process predicted labels
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # The inferer receives a batch, hence the added leading axis
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
        data = post_transforms(data)

    # Convert tensors back to numpy arrays
    image = data['image'].cpu().detach().numpy()
    pred = data['pred'].cpu().detach().numpy()

    # The slider range is fixed while the volume depth depends on the upload,
    # so coerce the requested index to an int and clamp it to a valid slice.
    depth = min(image.shape[-1], pred.shape[-1])
    z = min(max(int(round(z_axis)), 0), depth - 1)

    # Magnetic resonance imaging sequences
    t1c = image[0, :, :, z]  # T1-weighted, post contrast
    t1 = image[1, :, :, z]  # T1-weighted, pre contrast
    t2 = image[2, :, :, z]  # T2-weighted
    flair = image[3, :, :, z]  # FLAIR

    # BraTS labels (prediction keeps the batch axis, hence the leading 0)
    tc = pred[0, 0, :, :, z]  # Tumor core
    wt = pred[0, 1, :, :, z]  # Whole tumor
    et = pred[0, 2, :, :, z]  # Enhancing tumor

    return [t1c, t1, t2, flair], [tc, wt, et]


# Use blocks to set up a more complex demo
with gr.Blocks() as demo:
    # Render the page heading and the usage notes / disclaimer so the text
    # prepared at module level is actually shown to the user.
    gr.HTML(title)
    gr.Markdown(description)

    with gr.Row():
        # Get the input file and slice slider as inputs
        input_file = gr.File(label='input file')
        # step=1 keeps the slider on whole slice indices
        z_axis = gr.Slider(0, 200, label='slice', value=50, step=1)

    with gr.Row():
        # Show the button with custom label
        button = gr.Button("Segment Tumor!")

    with gr.Row():
        with gr.Column():
            # Show the input image with different MR sequences
            input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')

        with gr.Column():
            # Show the segmentation labels
            output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')

    # Run prediction on button click
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )

    # Have some examples for the user to try out. The component is not bound
    # to a name, so the module-level `examples` list is not overwritten.
    gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )

    # Show the literature references at the bottom of the page
    gr.Markdown(references)

# Launch the demo
demo.launch()