# Author: ahmed-7124 — commit "Create app.py" (f69d338, verified)
import gradio as gr
from transformers import pipeline
import pandas as pd
import PyPDF2
import pdfplumber
import torch
import timm
from PIL import Image
# Zero-shot text classifier: scores a free-text report against an arbitrary
# list of candidate labels (here, the keys of ``disease_details``).
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# Pretrained timm ResNet-50 used by the image tab. ``Module.eval()`` returns
# the module itself, so inference mode is switched on inline.
image_model = timm.create_model(model_name="resnet50", pretrained=True).eval()
# In-memory registry of patients: one dict per registered patient. It is a
# plain list, so its contents exist only while the process is running.
patients_db = []

# Conditions the text-report classifier can choose between, each mapped to an
# example medication and a general precaution shown to the user. Entry order
# is preserved and is the order of the classifier's candidate labels.
disease_details = {
    condition: {"medication": medication, "precaution": precaution}
    for condition, medication, precaution in (
        (
            "anemia",
            "Iron supplements (e.g., Ferrous sulfate)",
            "Increase intake of iron-rich foods like spinach and red meat.",
        ),
        (
            "viral infection",
            "Antiviral drugs (e.g., Oseltamivir for flu)",
            "Rest, stay hydrated, and avoid close contact with others.",
        ),
        (
            "liver disease",
            "Hepatoprotective drugs (e.g., Ursodeoxycholic acid)",
            "Avoid alcohol and maintain a balanced diet.",
        ),
        (
            "kidney disease",
            "Angiotensin-converting enzyme inhibitors (e.g., Lisinopril)",
            "Monitor salt intake and stay hydrated.",
        ),
        (
            "diabetes",
            "Metformin or insulin therapy",
            "Follow a low-sugar diet and exercise regularly.",
        ),
        (
            "hypertension",
            "Antihypertensive drugs (e.g., Amlodipine)",
            "Reduce salt intake and manage stress.",
        ),
        (
            "COVID-19",
            "Supportive care, antiviral drugs (e.g., Remdesivir in severe cases)",
            "Follow isolation protocols, wear a mask, and stay hydrated.",
        ),
        (
            "pneumonia",
            "Antibiotics (e.g., Amoxicillin) if bacterial",
            "Rest, avoid smoking, and stay hydrated.",
        ),
    )
}
# Function to register patients
def register_patient(name, age, gender):
    """Add a new patient record to ``patients_db`` and return a status message.

    Args:
        name: Patient's full name from the registration textbox. Surrounding
            whitespace is stripped; a blank name is rejected.
        age: Value of the age Number field. A whole-number float such as
            ``30.0`` is stored as ``int`` (so the dashboard shows ``30``);
            any other value is stored unchanged.
        gender: Selected gender choice (may be None if nothing was picked).

    Returns:
        A user-facing confirmation containing the new patient ID, or a
        warning message when the name is blank (nothing is stored then).
    """
    name = (name or "").strip()
    if not name:
        return "⚠️ Please enter the patient's name before registering."
    if isinstance(age, float) and age.is_integer():
        age = int(age)
    # IDs are 1-based and sequential. Nothing visible in this module deletes
    # records, so the current list length yields the next unused ID.
    patient_id = len(patients_db) + 1
    patients_db.append({
        "ID": patient_id,
        "Name": name,
        "Age": age,
        "Gender": gender,
        "Symptoms": "",
        "Diagnosis": "",
        "Action Plan": "",
        "Medications": "",
        "Precautions": "",
        "Tests": "",
    })
    return f"✅ Patient {name} registered successfully. Patient ID: {patient_id}"
# Function to analyze text reports
def analyze_report(patient_id, report_text):
    """Classify a free-text medical report and store the result on the patient.

    The zero-shot ``classifier`` scores ``report_text`` against every key of
    ``disease_details``. The top-scoring label becomes the diagnosis, and its
    medication/precaution entries are looked up from that table.

    Args:
        patient_id: ID entered in the Patient ID field. The record in
            ``patients_db`` with an equal ``"ID"`` (if any) is updated.
        report_text: Report contents pasted by the user.

    Returns:
        A multi-line, user-facing summary. If no patient has ``patient_id``,
        a line is appended saying the result was not saved. Blank text is
        rejected before the classifier is invoked.
    """
    if not report_text or not report_text.strip():
        return "⚠️ Please provide report text to analyze."

    candidate_labels = list(disease_details.keys())
    result = classifier(report_text, candidate_labels)
    # The zero-shot pipeline returns labels ordered by score, highest first.
    diagnosis = result["labels"][0]

    details = disease_details[diagnosis]
    medication = details["medication"]
    precaution = details["precaution"]
    action_plan = f"You might have {diagnosis}. Please consult a doctor for confirmation."

    summary = (f"🔍 Diagnosis: {diagnosis}\n"
               f"🩺 Medications: {medication}\n"
               f"⚠️ Precautions: {precaution}\n"
               f"💡 {action_plan}")

    patient = next((p for p in patients_db if p["ID"] == patient_id), None)
    if patient is None:
        # Make the missing save visible instead of silently dropping it.
        return summary + f"\n⚠️ Patient ID {patient_id} not found; the result was not saved."

    patient.update({
        "Diagnosis": diagnosis,
        "Action Plan": action_plan,
        "Medications": medication,
        "Precautions": precaution,
    })
    return summary
# Function to extract text from PDF reports
def extract_pdf_report(pdf):
    """Extract the plain text of every page of an uploaded PDF.

    Args:
        pdf: Value produced by the file-upload component. Accepted forms are a
            filepath string, an object exposing the path as ``.name`` (a
            tempfile wrapper), or None when nothing was uploaded.

    Returns:
        The page texts joined with newlines, or "" when no file was given.
    """
    if pdf is None:
        return ""
    # Support both a plain path and a file-like wrapper carrying ``.name``.
    path = getattr(pdf, "name", pdf)
    with pdfplumber.open(path) as pdf_file:
        # extract_text() may yield None for pages without a text layer
        # (e.g. scanned images); treat those pages as empty.
        pages = [page.extract_text() or "" for page in pdf_file.pages]
    # Newline separator keeps the last word of one page from fusing with the
    # first word of the next.
    return "\n".join(pages)
# Function to analyze uploaded images (X-ray/CT-scan)
def analyze_image(patient_id, img):
    """Run an uploaded X-ray/CT image through ``image_model`` and record the label.

    Preprocessing:
        RGB → float in [0, 1] → NCHW batch of one → resize to 224x224 →
        normalize with the ImageNet mean/std.

    Args:
        patient_id: ID entered in the Patient ID field. The record in
            ``patients_db`` with an equal ``"ID"`` (if any) gets its
            ``"Diagnosis"`` updated.
        img: Filepath of the uploaded image (the component is created with
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        A user-facing message with the predicted label. A warning line is
        appended when no patient matches ``patient_id``.

    NOTE(review): ``image_model`` is a generic pretrained timm ResNet-50, not
    a radiology model. The ``labels`` table maps its raw class indices 0-3
    to medical names, so the reported "diagnosis" is not clinically
    meaningful. Replace the model with one trained on medical images.
    """
    if img is None:
        return "⚠️ Please upload an image to analyze."

    # Local import: numpy is only needed here to turn the PIL image into an array.
    import numpy as np

    # Context manager closes the file handle; convert() returns a loaded copy.
    with Image.open(img) as opened:
        image = opened.convert("RGB")

    # HWC uint8 -> CHW float in [0, 1] -> batch of one (NCHW).
    pixels = torch.from_numpy(np.array(image)).permute(2, 0, 1).float().div(255.0)
    batch = torch.nn.functional.interpolate(
        pixels.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
    )
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    batch = (batch - mean) / std

    with torch.no_grad():
        output = image_model(batch)
    predicted = int(output.argmax(dim=1).item())

    # Map prediction to a label (see NOTE above: indices are not medical classes).
    labels = {0: "Normal", 1: "Pneumonia", 2: "Liver Disorder", 3: "COVID-19"}
    diagnosis = labels.get(predicted, "Unknown")
    message = f"🔍 Diagnosis from image: {diagnosis}"

    patient = next((p for p in patients_db if p["ID"] == patient_id), None)
    if patient is None:
        return message + f"\n⚠️ Patient ID {patient_id} not found; the result was not saved."
    patient["Diagnosis"] = diagnosis
    return message
# Function to display the dashboard
def show_dashboard():
    """Return every patient record as a DataFrame for the dashboard tab.

    This function always returns a DataFrame because the bound output
    component is a "dataframe". A bare message string is not a valid table
    value for that component. When no patients are registered, the result
    is an empty table that still shows the column headers.

    Returns:
        pandas.DataFrame with one row per patient, in registration order.
    """
    # Mirrors the keys of the record dict built by register_patient.
    columns = [
        "ID", "Name", "Age", "Gender", "Symptoms", "Diagnosis",
        "Action Plan", "Medications", "Precautions", "Tests",
    ]
    return pd.DataFrame(patients_db, columns=columns)
# Text-to-Image model interface
# Handle to the hosted text-to-image model; loaded on first use by generate_image.
_text_to_image_model = None


def generate_image(prompt):
    """Generate an image for ``prompt`` with the ZB-Tech/Text-to-Image model.

    The model is loaded once with ``gr.load`` and reused on later calls. It
    is then called like a function with the prompt, and its result is
    returned unchanged. Presumably that result is an image filepath that the
    "image" output can display; confirm against the model.

    Args:
        prompt: Text description of the desired image.

    Returns:
        The model's output, or None for a blank prompt (nothing is generated).
    """
    global _text_to_image_model
    if not prompt or not prompt.strip():
        return None
    if _text_to_image_model is None:
        _text_to_image_model = gr.load("models/ZB-Tech/Text-to-Image")
    return _text_to_image_model(prompt)
# Each tab's Interface below is built as gr.Interface(fn, inputs, outputs, ...).
# The order of input components matches the parameters of the bound function.

# Tab: register a new patient (name, age, gender -> confirmation text).
patient_interface = gr.Interface(
    register_patient,
    [
        gr.Textbox(label="Patient Name", placeholder="Enter the patient's full name"),
        gr.Number(label="Age"),
        gr.Radio(label="Gender", choices=["Male", "Female", "Other"]),
    ],
    "text",
    description="Register a new patient",
)

# Tab: classify pasted report text for a given patient ID.
report_interface = gr.Interface(
    analyze_report,
    [
        gr.Number(label="Patient ID"),
        gr.Textbox(label="Report Text", placeholder="Paste the text from your report here"),
    ],
    "text",
    description="Analyze blood, LFT, or other medical reports",
)

# Tab: upload a PDF and show the text extracted from it.
pdf_report_interface = gr.Interface(
    extract_pdf_report,
    gr.File(label="Upload PDF Report"),
    "text",
    description="Extract and analyze text from PDF reports",
)

# Tab: classify an uploaded scan; the image arrives as a filepath.
image_interface = gr.Interface(
    analyze_image,
    [
        gr.Number(label="Patient ID"),
        gr.Image(type="filepath", label="Upload X-ray or CT-Scan Image"),
    ],
    "text",
    description="Analyze X-ray or CT-scan images for diagnosis",
)

# Tab: table of all patient records (no inputs).
dashboard_interface = gr.Interface(
    show_dashboard,
    None,
    "dataframe",
    description="View patient reports and history",
)

# Tab: prompt -> generated image.
text_to_image_interface = gr.Interface(
    generate_image,
    gr.Textbox(label="Enter a prompt to generate an image"),
    "image",
    description="Generate images from text prompts",
)
def _toggle_sections(selected):
    """Return visibility updates for (medical section, text-to-image section)."""
    show_medical = selected == "Medical Reports"
    return gr.update(visible=show_medical), gr.update(visible=not show_medical)


# Organize the layout using Blocks. The "Choose Functionality" radio switches
# between the medical tabs and the text-to-image tab.
with gr.Blocks() as demo:
    gr.Markdown("# Medical Report and Image Analyzer + Text-to-Image Generator")
    choice = gr.Radio(
        label="Choose Functionality",
        choices=["Medical Reports", "Text-to-Image"],
        value="Medical Reports",
    )
    # Visible by default to match the radio's initial value.
    with gr.Column(visible=True) as medical_section:
        with gr.Tabs():
            with gr.TabItem("Patient Registration"):
                patient_interface.render()
            with gr.TabItem("Analyze Report (Text)"):
                report_interface.render()
            with gr.TabItem("Analyze Report (PDF)"):
                pdf_report_interface.render()
            with gr.TabItem("Analyze Image (X-ray/CT)"):
                image_interface.render()
            with gr.TabItem("Dashboard"):
                dashboard_interface.render()
    with gr.Column(visible=False) as text_to_image_section:
        with gr.Tabs():
            with gr.TabItem("Text-to-Image"):
                text_to_image_interface.render()
    choice.change(
        _toggle_sections,
        inputs=choice,
        outputs=[medical_section, text_to_image_section],
    )

# Launch only when run as a script, so importing this module has no side
# effect beyond building ``demo``.
if __name__ == "__main__":
    demo.launch(share=True)