A-PolarBear's picture
Update app.py
58a44fc
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
Orientationd,
NormalizeIntensityd,
Activationsd,
AsDiscreted,
ScaleIntensityd,
)
# Name of the MONAI bundle and the local directory it is read from.
# The name must match the bundle actually downloaded below
# (katielink/brats_mri_segmentation_v0.1.0); it previously referred to an
# unrelated spleen CT bundle, which made the config path misleading.
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
# Page title (HTML) plus Markdown usage notes and references for the demo.
title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠 </h1>'
# Usage notes. Segmentation only runs when the button is clicked, so the text
# tells the user to click again after moving the slider.
description = """
## 🚀 To run
Upload a brain MRI volume (e.g. a `.nii.gz` NIfTI file containing the four sequences T1+, T1, T2 and FLAIR), or try out one of the examples below!
To see a different slice, move the slider and click **Segment Tumor!** again.
More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)
## ⚠️ Disclaimer
This is an example, not to be used for diagnostic purposes.
"""
references = """
## 👀 References
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""
# Example inputs as [file path, slice index] pairs, matching predict's inputs.
examples = [
    ['examples/BRATS_485.nii.gz', 65],
    ['examples/BRATS_486.nii.gz', 80],
]
# Fetch the pretrained BraTS segmentation bundle from the Hugging Face Hub and
# load its network as a TorchScript module. bundle.load returns three values
# here; only the first (the network) is used by this app.
_bundle_outputs = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
model, _, _ = _bundle_outputs
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Parse the downloaded bundle's inference.json so that individual components
# (the inferer, below) can be built from its configuration.
parser = bundle.load_bundle_config(
    BUNDLE_PATH,
    'inference.json',
)
# Preprocessing pipeline for the uploaded volume, applied to the "image" key:
#   1. load the file(s) listed under "image",
#   2. ensure the channel axis comes first,
#   3. reorient to RAS,
#   4. normalise intensities channel by channel over non-zero voxels only.
preproc_transforms = Compose(
    transforms=[
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", channel_wise=True, nonzero=True),
    ],
)
# Instantiate the inferer exactly as the bundle's inference config defines it
# (expressions evaluated, object instantiated).
inferer = parser.get_parsed_content(
    'inferer',
    instantiate=True,
    eval_expr=True,
    lazy=True,
)
# Postprocessing: turn the raw network output under "pred" into binary masks
# (sigmoid, then threshold at 0.5) and rescale "image" to [0, 1] for display.
post_transforms = Compose(
    transforms=[
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        ScaleIntensityd(keys='image', minv=0.0, maxv=1.0),
    ],
)
def _input_path(input_file):
    """Return the filesystem path held by a ``gr.File`` value.

    Depending on the Gradio version the value is either a tempfile-like object
    exposing ``.name`` (what this app originally assumed) or a plain path
    string; both are accepted.

    Raises:
        gr.Error: if no file was provided.
    """
    if input_file is None:
        raise gr.Error('Please upload an MRI file or select one of the examples first.')
    return getattr(input_file, 'name', input_file)


def predict(input_file, z_axis, model=model, device=device):
    """Segment a brain MRI volume and return one slice of the input and output.

    Args:
        input_file: value of the ``gr.File`` component (tempfile-like object
            or path string) pointing at the MRI volume.
        z_axis: index of the slice to show along the last spatial axis; may
            arrive as a float from the slider.
        model: network to run; defaults to the bundle model loaded at startup.
        device: torch device string used for inference.

    Returns:
        A pair of lists of 2-D numpy arrays taken at slice ``z_axis``:
        the four MR sequences ``[T1+, T1, T2, FLAIR]`` and the three predicted
        label maps ``[TC, WT, ET]``.

    Raises:
        gr.Error: if no file was given or ``z_axis`` is outside the volume.
    """
    # Load and preprocess the volume in MONAI's dictionary format.
    data = preproc_transforms({'image': [_input_path(input_file)]})

    # The slider range is not tied to the volume depth, so validate the slice
    # index before running the (expensive) inference instead of failing after.
    z = int(z_axis)
    depth = data['image'].shape[-1]
    if not 0 <= z < depth:
        raise gr.Error(
            f'Slice {z} is out of range: this volume has {depth} slices (0-{depth - 1}).'
        )

    # Run inference and post-process the predicted labels.
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # Add a batch dimension for the inferer; the prediction keeps it.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    # Convert both tensors to numpy the same way, whichever device they are on.
    image = data['image'].detach().cpu().numpy()
    pred = data['pred'].detach().cpu().numpy()

    # Image channels in order: T1 post-contrast, T1 pre-contrast, T2, FLAIR.
    sequences = [image[channel, :, :, z] for channel in range(4)]
    # Prediction channels in order: tumor core, whole tumor, enhancing tumor.
    labels = [pred[0, channel, :, :, z] for channel in range(3)]
    return sequences, labels
# Build the demo page: header and usage notes, inputs, a trigger button, two
# result galleries, clickable examples, and the references.
with gr.Blocks() as demo:
    # Render the title and description (previously defined but never shown).
    gr.HTML(title)
    gr.Markdown(description)
    with gr.Row():
        # The input volume and the index of the slice to display.
        input_file = gr.File(label='input file')
        # step=1 keeps the slice index integral.
        z_axis = gr.Slider(0, 200, step=1, label='slice', value=50)
    with gr.Row():
        button = gr.Button("Segment Tumor!")
    with gr.Row():
        with gr.Column():
            # The four MR sequences at the selected slice.
            input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')
        with gr.Column():
            # The three predicted label maps at the selected slice.
            output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')
    # Prediction runs only on button click; moving the slider alone does not
    # refresh the outputs.
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )
    # Example inputs for the user to try. The component is not assigned back
    # to `examples`, so the example list keeps its original meaning.
    gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )
    gr.Markdown(references)

# Launch the demo
demo.launch()