import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Define the bundle name and path for downloading
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
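# Note: torch.hub.get_dir() usually resolves to ~/.cache/torch/hub (or
# $TORCH_HOME/hub if that variable is set), and <hub dir>/bundle is MONAI's
# default download location, so BUNDLE_PATH should match where bundle.load
# below places the downloaded bundle files.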
# Title and description
title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠 </h1>'
description = """
## 🚀 To run
Upload a brain MRI image file, or try out one of the examples below!
If you want to see a different slice, update the slider.
More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)
## ⚠️ Disclaimer
This is an example and should not be used for diagnostic purposes.
"""
references = """
## 👀 References
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)." IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015). DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features." Nature Scientific Data, 4:170117 (2017). DOI: 10.1038/sdata.2017.117
"""
examples = [
    ['examples/BRATS_485.nii.gz', 65],
    ['examples/BRATS_486.nii.gz', 80],
]
# Load the MONAI pretrained model from Hugging Face Hub
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
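# With load_ts_module=True, bundle.load is expected to return a
# (model, metadata, extra_files) triple; only the network itself is kept
# here and the other two values are discarded.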
# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the parser from the MONAI bundle's inference config
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
# Compose the preprocessing transforms
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
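# A rough sketch of the dictionary flowing through these transforms (shapes are
# an assumption based on the 4-sequence BraTS input and the indexing below):
#
#   {'image': ['/path/to/BRATS_xxx.nii.gz']}   # before LoadImaged
#   {'image': <tensor, shape (4, H, W, D)>}    # after the full pipeline:
#                                              # channel-first, RAS, normalized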
# Get the inferer from the bundle's inference config
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True, eval_expr=True, instantiate=True
)
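# The configured inferer is typically a sliding-window inferer that tiles the
# full volume into patches for the network (an assumption based on common MONAI
# segmentation bundles; the authoritative definition is in the bundle's
# inference.json).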
# Compose the postprocessing transforms
post_transforms = Compose(
    [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)
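# The sigmoid + 0.5 threshold turns the three logit channels into overlapping
# binary masks (tumor core, whole tumor, enhancing tumor, per the BraTS label
# convention), while ScaleIntensityd maps the input image to [0, 1] so the
# Gallery can render the slices.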
# Define the predict function for the demo
def predict(input_file, z_axis, model=model, device=device):
    # Load and preprocess the input image in MONAI's dictionary format
    data = {'image': [input_file.name]}
    data = preproc_transforms(data)

    # Run inference and post-process the predicted labels
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
        data = post_transforms(data)

    # Convert tensors back to numpy arrays for display
    data['image'] = data['image'].numpy()
    data['pred'] = data['pred'].cpu().detach().numpy()

    # Slice the requested plane from each MRI sequence
    t1c = data['image'][0, :, :, z_axis]    # T1-weighted, post-contrast
    t1 = data['image'][1, :, :, z_axis]     # T1-weighted, pre-contrast
    t2 = data['image'][2, :, :, z_axis]     # T2-weighted
    flair = data['image'][3, :, :, z_axis]  # FLAIR

    # Slice the same plane from each predicted BraTS label channel
    tc = data['pred'][0, 0, :, :, z_axis]  # Tumor core
    wt = data['pred'][0, 1, :, :, z_axis]  # Whole tumor
    et = data['pred'][0, 2, :, :, z_axis]  # Enhancing tumor
    return [t1c, t1, t2, flair], [tc, wt, et]
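# A minimal smoke test sketch (hypothetical; outside the UI, gr.File hands
# predict() an object with a .name attribute, so a tiny stand-in suffices):
#
#   class _Upload:
#       name = 'examples/BRATS_485.nii.gz'
#   sequences, labels = predict(_Upload(), z_axis=65)
#   assert len(sequences) == 4 and len(labels) == 3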
# Use Gradio Blocks to set up a multi-component demo
with gr.Blocks() as demo:
    # Render the title and description defined above
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Take the input file and the slice slider as inputs
        input_file = gr.File(label='input file')
        z_axis = gr.Slider(0, 200, label='slice', value=50)
    with gr.Row():
        # Show the button with a custom label
        button = gr.Button('Segment Tumor!')
    with gr.Row():
        with gr.Column():
            # Show the input image with the different MR sequences
            input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')
        with gr.Column():
            # Show the predicted segmentation labels
            output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')
    # Run prediction on button click
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )
    # Provide some examples for the user to try out
    examples = gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )
    # Render the references at the bottom of the page
    gr.Markdown(references)
# Launch the demo
demo.launch()
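# When running locally, demo.launch(share=True) would also expose a temporary
# public URL (standard Gradio option).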