ahmed-7124
committed on
Create app.py
app.py ADDED
@@ -0,0 +1,219 @@
import gradio as gr
from transformers import pipeline
import pandas as pd
import PyPDF2  # kept from the original imports; not used directly below
import pdfplumber
import torch
import timm
from torchvision import transforms  # used for image preprocessing in analyze_image()
from PIL import Image

# Load pre-trained model for zero-shot classification
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

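# Illustrative note (not executed by the app): the zero-shot pipeline scores the
# report text against every candidate label and returns the labels ranked, so
# result["labels"][0] is the most likely disease. Scores below are made up purely
# for illustration:
#
#   result = classifier(
#       "Fasting glucose and HbA1c are elevated; patient reports excessive thirst.",
#       ["diabetes", "anemia", "hypertension"],
#   )
#   result["labels"]  -> ["diabetes", "hypertension", "anemia"]
#   result["scores"]  -> [0.88, 0.08, 0.04]   # illustrative values only
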
# Pre-trained model for X-ray analysis (example with a model from timm library)
# NOTE: pretrained=True loads ImageNet weights, so the classifier head has 1000
# generic classes rather than the four medical labels used in analyze_image() below.
image_model = timm.create_model('resnet50', pretrained=True)
image_model.eval()

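# Sketch (assumption, not wired in): mapping model outputs onto the four medical
# labels used below would require a checkpoint fine-tuned on chest X-rays. With
# timm, such a head could be created and loaded roughly like this:
#
#   xray_model = timm.create_model('resnet50', pretrained=False, num_classes=4)
#   xray_model.load_state_dict(torch.load("xray_resnet50.pt"))  # hypothetical checkpoint
#   xray_model.eval()
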
# Initialize patient database
patients_db = []

# Disease and medication mapping
disease_details = {
    "anemia": {
        "medication": "Iron supplements (e.g., Ferrous sulfate)",
        "precaution": "Increase intake of iron-rich foods like spinach and red meat."
    },
    "viral infection": {
        "medication": "Antiviral drugs (e.g., Oseltamivir for flu)",
        "precaution": "Rest, stay hydrated, and avoid close contact with others."
    },
    "liver disease": {
        "medication": "Hepatoprotective drugs (e.g., Ursodeoxycholic acid)",
        "precaution": "Avoid alcohol and maintain a balanced diet."
    },
    "kidney disease": {
        "medication": "Angiotensin-converting enzyme inhibitors (e.g., Lisinopril)",
        "precaution": "Monitor salt intake and stay hydrated."
    },
    "diabetes": {
        "medication": "Metformin or insulin therapy",
        "precaution": "Follow a low-sugar diet and exercise regularly."
    },
    "hypertension": {
        "medication": "Antihypertensive drugs (e.g., Amlodipine)",
        "precaution": "Reduce salt intake and manage stress."
    },
    "COVID-19": {
        "medication": "Supportive care, antiviral drugs (e.g., Remdesivir in severe cases)",
        "precaution": "Follow isolation protocols, wear a mask, and stay hydrated."
    },
    "pneumonia": {
        "medication": "Antibiotics (e.g., Amoxicillin) if bacterial",
        "precaution": "Rest, avoid smoking, and stay hydrated."
    }
}

# Function to register patients
def register_patient(name, age, gender):
    patient_id = len(patients_db) + 1
    patients_db.append({
        "ID": patient_id,
        "Name": name,
        "Age": age,
        "Gender": gender,
        "Symptoms": "",
        "Diagnosis": "",
        "Action Plan": "",
        "Medications": "",
        "Precautions": "",
        "Tests": ""
    })
    return f"✅ Patient {name} registered successfully. Patient ID: {patient_id}"

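# Example call (illustrative): registering the first patient returns
#   register_patient("Jane Doe", 42, "Female")
#   -> "✅ Patient Jane Doe registered successfully. Patient ID: 1"
# Note that patients_db lives in memory only, so records are lost on restart.
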
# Function to analyze text reports
def analyze_report(patient_id, report_text):
    candidate_labels = list(disease_details.keys())
    result = classifier(report_text, candidate_labels)
    diagnosis = result['labels'][0]

    # Fetch medication and precaution
    medication = disease_details[diagnosis]["medication"]
    precaution = disease_details[diagnosis]["precaution"]
    action_plan = f"You might have {diagnosis}. Please consult a doctor for confirmation."

    # Store diagnosis in the database
    for patient in patients_db:
        if patient["ID"] == patient_id:
            patient["Diagnosis"] = diagnosis
            patient["Action Plan"] = action_plan
            patient["Medications"] = medication
            patient["Precautions"] = precaution
            break

    return (f"🔍 Diagnosis: {diagnosis}\n"
            f"🩺 Medications: {medication}\n"
            f"⚠️ Precautions: {precaution}\n"
            f"💡 {action_plan}")

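# Illustrative output (the actual label depends on the model's ranking): for a
# report mentioning low hemoglobin, the top zero-shot label would typically be
# "anemia", so the returned string would read
#   🔍 Diagnosis: anemia
#   🩺 Medications: Iron supplements (e.g., Ferrous sulfate)
#   ⚠️ Precautions: Increase intake of iron-rich foods like spinach and red meat.
#   💡 You might have anemia. Please consult a doctor for confirmation.
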
# Function to extract text from PDF reports
def extract_pdf_report(pdf):
    # gr.File may hand back either a file path string or a tempfile-like object,
    # depending on the Gradio version, so accept both.
    pdf_path = pdf.name if hasattr(pdf, "name") else pdf
    text = ""
    with pdfplumber.open(pdf_path) as pdf_file:
        for page in pdf_file.pages:
            # extract_text() returns None for pages with no extractable text
            # (e.g., scanned images), so guard against concatenating None.
            text += page.extract_text() or ""
    return text

# Function to analyze uploaded images (X-ray/CT-scan)
def analyze_image(patient_id, img):
    image = Image.open(img).convert('RGB')
    # Standard ImageNet preprocessing: resize, convert to a tensor, normalize.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = preprocess(image).unsqueeze(0)  # shape: [1, 3, 224, 224]

    # Run the image through the model (for simplicity, assuming ResNet50 output)
    with torch.no_grad():
        output = image_model(image_tensor)
    _, predicted = torch.max(output, 1)

    # Map prediction to a label.
    # NOTE: with the ImageNet-pretrained head, predicted.item() ranges over 1000
    # generic classes, so this mapping is a placeholder and will usually fall
    # through to "Unknown" unless a fine-tuned 4-class model is used (see above).
    labels = {0: "Normal", 1: "Pneumonia", 2: "Liver Disorder", 3: "COVID-19"}
    diagnosis = labels.get(predicted.item(), "Unknown")

    # Store diagnosis in the database
    for patient in patients_db:
        if patient["ID"] == patient_id:
            patient["Diagnosis"] = diagnosis
            break

    return f"🔍 Diagnosis from image: {diagnosis}"

# Function to display the dashboard
def show_dashboard():
    if not patients_db:
        # Return an empty table rather than a plain string so the Dataframe
        # output component always receives tabular data.
        return pd.DataFrame(columns=["ID", "Name", "Age", "Gender", "Symptoms",
                                     "Diagnosis", "Action Plan", "Medications",
                                     "Precautions", "Tests"])
    return pd.DataFrame(patients_db)

# Text-to-Image model, loaded once from the Hugging Face Hub.
# Assumption: the demo returned by gr.load() can be called like a function with a
# single prompt; calling .launch() inside the callback would start a second app
# instead of returning an image, so the model is loaded at module level and reused.
text_to_image_model = gr.load("models/ZB-Tech/Text-to-Image")

def generate_image(prompt):
    return text_to_image_model(prompt)

# Gradio interface for patient registration
patient_interface = gr.Interface(
    fn=register_patient,
    inputs=[
        gr.Textbox(label="Patient Name", placeholder="Enter the patient's full name"),
        gr.Number(label="Age"),
        gr.Radio(label="Gender", choices=["Male", "Female", "Other"])
    ],
    outputs="text",
    description="Register a new patient"
)

# Gradio interface for report analysis (text input)
report_interface = gr.Interface(
    fn=analyze_report,
    inputs=[
        gr.Number(label="Patient ID"),
        gr.Textbox(label="Report Text", placeholder="Paste the text from your report here")
    ],
    outputs="text",
    description="Analyze blood, LFT, or other medical reports"
)

# Gradio interface for PDF report analysis (PDF upload)
pdf_report_interface = gr.Interface(
    fn=extract_pdf_report,
    inputs=gr.File(label="Upload PDF Report"),
    outputs="text",
    description="Extract and analyze text from PDF reports"
)

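# Possible extension (sketch, not wired into the UI): the PDF tab only extracts
# text; chaining it into the zero-shot analyzer would look roughly like this,
# where analyze_pdf_report is a hypothetical helper:
#
#   def analyze_pdf_report(patient_id, pdf):
#       return analyze_report(patient_id, extract_pdf_report(pdf))
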
# Gradio interface for X-ray/CT-scan image analysis
image_interface = gr.Interface(
    fn=analyze_image,
    inputs=[
        gr.Number(label="Patient ID"),
        gr.Image(type="filepath", label="Upload X-ray or CT-Scan Image")
    ],
    outputs="text",
    description="Analyze X-ray or CT-scan images for diagnosis"
)

# Gradio interface for the dashboard
dashboard_interface = gr.Interface(
    fn=show_dashboard,
    inputs=None,
    outputs="dataframe",
    description="View patient reports and history"
)

# Gradio interface for text-to-image
text_to_image_interface = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter a prompt to generate an image"),
    outputs="image",
    description="Generate images from text prompts"
)

# Organize the layout using Blocks
with gr.Blocks() as demo:
    gr.Markdown("# Medical Report and Image Analyzer + Text-to-Image Generator")
    # NOTE: this radio is informational only; it is not attached to any event
    # handler, so changing it does not alter which tabs are shown.
    choice = gr.Radio(label="Choose Functionality", choices=["Medical Reports", "Text-to-Image"], value="Medical Reports")

    with gr.Column():
        with gr.TabItem("Patient Registration"):
            patient_interface.render()
        with gr.TabItem("Analyze Report (Text)"):
            report_interface.render()
        with gr.TabItem("Analyze Report (PDF)"):
            pdf_report_interface.render()
        with gr.TabItem("Analyze Image (X-ray/CT)"):
            image_interface.render()
        with gr.TabItem("Dashboard"):
            dashboard_interface.render()
        with gr.TabItem("Text-to-Image"):
            text_to_image_interface.render()

demo.launch(share=True)
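
# Dependencies (sketch, inferred from the imports above): a requirements.txt for
# this Space would need at least
#   gradio, transformers, torch, torchvision, timm, pdfplumber, PyPDF2, pandas, pillow
# The facebook/bart-large-mnli checkpoint and the ResNet50 weights are downloaded
# from the Hub on first run.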