|
import gradio as gr |
|
from transformers import pipeline |
|
import pandas as pd |
|
import PyPDF2 |
|
import pdfplumber |
|
import torch |
|
import timm |
|
from PIL import Image |
|
|
|
|
|
# Zero-shot text classifier, built once when the module is imported.
# analyze_report() scores report text against the keys of disease_details
# with it.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
|
|
|
|
|
# Image classifier used by analyze_image(), switched to inference mode.
# nn.Module.eval() returns the module itself, so the call is chained.
# NOTE(review): timm's pretrained "resnet50" weights are, as far as I know,
# general-purpose ImageNet weights rather than a medical-imaging model —
# confirm before trusting any diagnosis derived from its outputs.
image_model = timm.create_model("resnet50", pretrained=True).eval()
|
|
|
|
|
# In-memory patient registry: one dict per patient. register_patient()
# appends to it, and the analyze_* handlers update entries in place.
# It is a plain list, so its contents are lost when the process exits.
patients_db = list()
|
|
|
|
|
# Condition -> suggested medication and precaution.
# The keys double as the candidate labels handed to the zero-shot
# classifier in analyze_report(). Label text and lookup keys must therefore
# stay identical; note that "COVID-19" is the only mixed-case key.
# Built from (label, medication, precaution) rows, and the dict preserves
# the row order below.
disease_details = {
    label: {"medication": medication, "precaution": precaution}
    for label, medication, precaution in (
        (
            "anemia",
            "Iron supplements (e.g., Ferrous sulfate)",
            "Increase intake of iron-rich foods like spinach and red meat.",
        ),
        (
            "viral infection",
            "Antiviral drugs (e.g., Oseltamivir for flu)",
            "Rest, stay hydrated, and avoid close contact with others.",
        ),
        (
            "liver disease",
            "Hepatoprotective drugs (e.g., Ursodeoxycholic acid)",
            "Avoid alcohol and maintain a balanced diet.",
        ),
        (
            "kidney disease",
            "Angiotensin-converting enzyme inhibitors (e.g., Lisinopril)",
            "Monitor salt intake and stay hydrated.",
        ),
        (
            "diabetes",
            "Metformin or insulin therapy",
            "Follow a low-sugar diet and exercise regularly.",
        ),
        (
            "hypertension",
            "Antihypertensive drugs (e.g., Amlodipine)",
            "Reduce salt intake and manage stress.",
        ),
        (
            "COVID-19",
            "Supportive care, antiviral drugs (e.g., Remdesivir in severe cases)",
            "Follow isolation protocols, wear a mask, and stay hydrated.",
        ),
        (
            "pneumonia",
            "Antibiotics (e.g., Amoxicillin) if bacterial",
            "Rest, avoid smoking, and stay hydrated.",
        ),
    )
}
|
|
|
|
|
def register_patient(name, age, gender):
    """Add a new patient to the in-memory registry and return a status message.

    Args:
        name: Patient's full name as entered in the UI.
        age: Age value from the UI (presumably a number from ``gr.Number``);
            stored exactly as received.
        gender: Selected gender option; stored exactly as received.

    Returns:
        A confirmation message containing the newly assigned patient ID, or an
        error message (and no new record) when the name is blank.
    """
    # Reject blank names; otherwise an anonymous record would consume an ID.
    if name is None or not str(name).strip():
        return "❌ Please enter the patient's name before registering."

    # IDs are 1-based and sequential. Nothing in this module removes records,
    # so the list length equals the highest ID issued so far.
    patient_id = len(patients_db) + 1
    patients_db.append({
        "ID": patient_id,
        "Name": name,
        "Age": age,
        "Gender": gender,
        # Filled in later by the analysis handlers (some remain unused).
        "Symptoms": "",
        "Diagnosis": "",
        "Action Plan": "",
        "Medications": "",
        "Precautions": "",
        "Tests": "",
    })
    return f"✅ Patient {name} registered successfully. Patient ID: {patient_id}"
|
|
|
|
|
def analyze_report(patient_id, report_text):
    """Classify report text into one of the known conditions and record the result.

    The zero-shot ``classifier`` scores ``report_text`` against every key of
    ``disease_details``, and the top label becomes the diagnosis. The matching
    medication and precaution hints are returned. They are also saved on the
    patient's record when ``patient_id`` matches a registered patient.

    Args:
        patient_id: ID of a registered patient (from the UI number field).
        report_text: Free-text report to classify.

    Returns:
        Multi-line text with the diagnosis, medication, precaution and a
        suggested next step. A note is appended when the patient ID is
        unknown. An explanatory message is returned instead when no report
        text was given.
    """
    # Nothing to classify: do not run the model on an empty sequence.
    if not report_text or not report_text.strip():
        return "⚠️ Please paste the report text to analyze."

    candidate_labels = list(disease_details)
    result = classifier(report_text, candidate_labels)
    # The pipeline documents its labels as sorted by likelihood, so index 0
    # is the best match.
    diagnosis = result["labels"][0]

    details = disease_details[diagnosis]
    medication = details["medication"]
    precaution = details["precaution"]
    action_plan = f"You might have {diagnosis}. Please consult a doctor for confirmation."

    patient = next((p for p in patients_db if p["ID"] == patient_id), None)
    if patient is not None:
        patient["Diagnosis"] = diagnosis
        patient["Action Plan"] = action_plan
        patient["Medications"] = medication
        patient["Precautions"] = precaution

    lines = [
        f"🔍 Diagnosis: {diagnosis}",
        f"🩺 Medications: {medication}",
        f"⚠️ Precautions: {precaution}",
        f"💡 {action_plan}",
    ]
    # Previously an unknown ID was silently ignored; tell the user instead.
    if patient is None:
        lines.append(
            f"ℹ️ Patient ID {patient_id} not found; result was not saved to the dashboard."
        )
    return "\n".join(lines)
|
|
|
|
|
def extract_pdf_report(pdf):
    """Return the text content of an uploaded PDF, one page after another.

    Args:
        pdf: The uploaded file. This is either an object exposing the
            temp-file path as ``.name`` or the path string itself, depending
            on the Gradio version / ``gr.File`` type. ``None`` when nothing
            was uploaded.

    Returns:
        The extracted text with pages separated by newlines. An empty string
        when no file was supplied.
    """
    # Accept both file-like wrappers (path in .name) and plain path strings.
    path = getattr(pdf, "name", pdf)
    if not path:
        return ""

    with pdfplumber.open(path) as pdf_file:
        # extract_text() may yield None for pages without a text layer (e.g.
        # scanned images); treat those as empty instead of crashing on `+=`.
        pages = [page.extract_text() or "" for page in pdf_file.pages]
    # Join with newlines so the last word of one page does not fuse with the
    # first word of the next.
    return "\n".join(pages)
|
|
|
|
|
def _image_file_to_batch(path, size=224):
    """Load an image file as a normalized float tensor of shape (1, 3, size, size)."""
    # Force 3 channels and the fixed square input size; closing the file
    # after the converted copy has been made.
    with Image.open(path) as src:
        rgb = src.convert("RGB").resize((size, size))
    # RGB-mode tobytes() is row-major interleaved (H, W, C) uint8. Reorder it
    # to (C, H, W) floats in [0, 1]; bytearray() gives torch.frombuffer a
    # writable buffer.
    pixels = torch.frombuffer(bytearray(rgb.tobytes()), dtype=torch.uint8)
    chw = pixels.view(size, size, 3).permute(2, 0, 1).float() / 255.0
    # Per-channel ImageNet mean/std normalization, then add a batch dimension.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return ((chw - mean) / std).unsqueeze(0)


def analyze_image(patient_id, img):
    """Run ``image_model`` on an uploaded scan and record the predicted label.

    Args:
        patient_id: ID of a registered patient (from the UI number field).
        img: Path to the uploaded image file (the UI uses
            ``type="filepath"``); ``None``/empty when nothing was uploaded.

    Returns:
        A message with the predicted label (plus a note when the patient ID is
        unknown), or a prompt to upload an image.
    """
    if not img:
        return "⚠️ Please upload an image to analyze."

    batch = _image_file_to_batch(img)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = image_model(batch)
    predicted = int(logits.argmax(dim=1).item())

    # NOTE(review): image_model is timm's generic "resnet50", created
    # elsewhere in this file. Its output indices presumably do not mean these
    # four conditions. This mapping is a placeholder until a model trained on
    # these classes is used — confirm before relying on any result.
    labels = {0: "Normal", 1: "Pneumonia", 2: "Liver Disorder", 3: "COVID-19"}
    diagnosis = labels.get(predicted, "Unknown")

    patient = next((p for p in patients_db if p["ID"] == patient_id), None)
    if patient is not None:
        patient["Diagnosis"] = diagnosis

    message = f"🔍 Diagnosis from image: {diagnosis}"
    if patient is None:
        message += (
            f"\nℹ️ Patient ID {patient_id} not found; result was not saved to the dashboard."
        )
    return message
|
|
|
|
|
# Column headers of a patient record, as created by register_patient().
_DASHBOARD_COLUMNS = [
    "ID", "Name", "Age", "Gender", "Symptoms", "Diagnosis",
    "Action Plan", "Medications", "Precautions", "Tests",
]


def show_dashboard(records=None):
    """Return patient records as a DataFrame for the dashboard table.

    Args:
        records: Optional list of patient dicts to display. Defaults to the
            module-level ``patients_db`` registry.

    Returns:
        A DataFrame with one row per patient. With no records, an empty
        DataFrame that keeps the registry's column headers. A plain string is
        not a valid value for the dashboard's "dataframe" output.
    """
    rows = patients_db if records is None else records
    if not rows:
        return pd.DataFrame(columns=_DASHBOARD_COLUMNS)
    return pd.DataFrame(rows)
|
|
|
|
|
# Lazily loaded handle to the hosted text-to-image model (see generate_image).
_text_to_image_model = None


def _get_text_to_image_model():
    """Load the hosted text-to-image model on first use and reuse it afterwards."""
    global _text_to_image_model
    if _text_to_image_model is None:
        _text_to_image_model = gr.load("models/ZB-Tech/Text-to-Image")
    return _text_to_image_model


def generate_image(prompt):
    """Generate an image for ``prompt`` with the hosted ZB-Tech/Text-to-Image model.

    Args:
        prompt: Text description of the desired image.

    Returns:
        The loaded model's output for the prompt (presumably an image file
        path the "image" output can display), or ``None`` for a blank prompt.
    """
    if not prompt or not prompt.strip():
        return None
    # gr.load() returns a Blocks app. Calling it runs its prediction function
    # instead of launching a separate server on every request.
    return _get_text_to_image_model()(prompt)
|
|
|
|
|
# UI for register_patient(): collect name, age and gender of a new patient.
patient_interface = gr.Interface(
    fn=register_patient,
    inputs=[
        gr.Textbox(label="Patient Name", placeholder="Enter the patient's full name"),
        # precision=0 makes the component deliver a whole number (int), so
        # ages are stored and shown on the dashboard without a ".0" suffix.
        gr.Number(label="Age", precision=0),
        gr.Radio(label="Gender", choices=["Male", "Female", "Other"]),
    ],
    outputs="text",
    description="Register a new patient",
)
|
|
|
|
|
# UI for analyze_report(): classify pasted report text for a registered patient.
report_interface = gr.Interface(
    fn=analyze_report,
    inputs=[
        # Patient IDs are assigned as ints; precision=0 delivers an int too.
        gr.Number(label="Patient ID", precision=0),
        gr.Textbox(label="Report Text", placeholder="Paste the text from your report here"),
    ],
    outputs="text",
    description="Analyze blood, LFT, or other medical reports",
)
|
|
|
|
|
# UI for extract_pdf_report(): upload a PDF and show its extracted text.
pdf_report_interface = gr.Interface(
    fn=extract_pdf_report,
    # Only offer PDF files, since the handler parses the upload as a PDF.
    inputs=gr.File(label="Upload PDF Report", file_types=[".pdf"]),
    outputs="text",
    # The handler only extracts text; analysis happens in the text tab.
    description="Extract text from PDF reports (paste it into 'Analyze Report (Text)' to analyze it)",
)
|
|
|
|
|
# UI for analyze_image(): run the image model on an uploaded scan.
image_interface = gr.Interface(
    fn=analyze_image,
    inputs=[
        # Patient IDs are assigned as ints; precision=0 delivers an int too.
        gr.Number(label="Patient ID", precision=0),
        # The handler opens the upload by path, hence type="filepath".
        gr.Image(type="filepath", label="Upload X-ray or CT-Scan Image"),
    ],
    outputs="text",
    description="Analyze X-ray or CT-scan images for diagnosis",
)
|
|
|
|
|
# UI for show_dashboard(): takes no inputs and renders the patient table.
dashboard_interface = gr.Interface(
    show_dashboard,
    None,
    "dataframe",
    description="View patient reports and history",
)
|
|
|
|
|
# UI for generate_image(): one prompt box in, one image out.
_prompt_box = gr.Textbox(label="Enter a prompt to generate an image")
text_to_image_interface = gr.Interface(
    generate_image,
    _prompt_box,
    "image",
    description="Generate images from text prompts",
)
|
|
|
|
|
def _section_visibility(selected):
    """Show either the medical-report tabs or the text-to-image tool, per the radio choice."""
    return (
        gr.update(visible=selected == "Medical Reports"),
        gr.update(visible=selected == "Text-to-Image"),
    )


with gr.Blocks() as demo:
    gr.Markdown("# Medical Report and Image Analyzer + Text-to-Image Generator")
    choice = gr.Radio(
        label="Choose Functionality",
        choices=["Medical Reports", "Text-to-Image"],
        value="Medical Reports",
    )

    # Medical tools: visible initially, matching the radio's default value.
    with gr.Column(visible=True) as medical_section:
        with gr.Tabs():
            with gr.TabItem("Patient Registration"):
                patient_interface.render()
            with gr.TabItem("Analyze Report (Text)"):
                report_interface.render()
            with gr.TabItem("Analyze Report (PDF)"):
                pdf_report_interface.render()
            with gr.TabItem("Analyze Image (X-ray/CT)"):
                image_interface.render()
            with gr.TabItem("Dashboard"):
                dashboard_interface.render()

    # Text-to-image tool: hidden until chosen with the radio.
    with gr.Column(visible=False) as text_to_image_section:
        text_to_image_interface.render()

    # Toggle the two sections whenever the functionality choice changes.
    choice.change(
        fn=_section_visibility,
        inputs=choice,
        outputs=[medical_section, text_to_image_section],
    )


if __name__ == "__main__":
    # NOTE(review): share=True publishes a public URL to an app that stores
    # patient data with no authentication — confirm this is intended.
    demo.launch(share=True)
|
|