AliArshad committed (verified)
Commit c08ced7 · 1 Parent(s): cdc5112

Create app.py

Files changed (1): app.py (+188, -0)
app.py ADDED
@@ -0,0 +1,188 @@
import gradio as gr
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    pipeline
)
from PIL import Image
import torch
import random
import json

class MedicalImageAnalysisSystem:
    def __init__(self):
        print("Initializing system...")

        # Check for CUDA availability
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load models one at a time with progress messages
        print("Loading tumor classifier...")
        self.tumor_classifier_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.tumor_classifier_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")

        print("Loading tumor radius detector...")
        self.radius_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.radius_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_radius_detection_finetuned")

        print("Loading language model...")
        # Using a smaller model for faster inference
        self.llm = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
            model_kwargs={"low_cpu_mem_usage": True}
        )

        print("System initialization complete!")

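    # Design note: weights load in float16 on CUDA to halve memory use; the
    # predict methods below cast inputs to the model's dtype to match.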
    def generate_synthetic_metadata(self):
        return {
            "age": random.randint(25, 85),
            "gender": random.choice(["Male", "Female"]),
            "smoking_status": random.choice(["Never Smoker", "Former Smoker", "Current Smoker"]),
            "drinking_status": random.choice(["Non-drinker", "Social Drinker", "Regular Drinker"]),
            "medications": random.sample([
                "Lisinopril", "Metformin", "Levothyroxine", "Amlodipine",
                "Metoprolol", "Omeprazole", "Simvastatin", "Losartan"
            ], random.randint(0, 3))
        }

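    # Illustrative output (values are re-randomized on every call):
    # {"age": 62, "gender": "Female", "smoking_status": "Former Smoker",
    #  "drinking_status": "Non-drinker", "medications": ["Metformin"]}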
    def process_image(self, image):
        if isinstance(image, str):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    @torch.no_grad()  # Disable gradient computation for inference
    def predict_tumor_presence(self, processed_image):
        inputs = self.tumor_classifier_processor(processed_image, return_tensors="pt")
        # Move inputs to the model's device and dtype (float16 weights on CUDA
        # would otherwise reject float32 pixel values)
        inputs = {k: v.to(self.device, dtype=self.tumor_classifier_model.dtype)
                  for k, v in inputs.items()}
        outputs = self.tumor_classifier_model(**inputs)
        predictions = torch.softmax(outputs.logits, dim=-1)
        probs = predictions[0].cpu().tolist()  # Move back to CPU before converting to Python floats
        return {
            "non-tumor": float(probs[0]),
            "tumor": float(probs[1])
        }

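    # The 0 -> non-tumor / 1 -> tumor ordering assumes the checkpoint's label
    # order; it can be confirmed via self.tumor_classifier_model.config.id2label.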
    @torch.no_grad()  # Disable gradient computation for inference
    def predict_tumor_radius(self, processed_image):
        inputs = self.radius_processor(processed_image, return_tensors="pt")
        # Same device/dtype handling as in predict_tumor_presence
        inputs = {k: v.to(self.device, dtype=self.radius_model.dtype)
                  for k, v in inputs.items()}
        outputs = self.radius_model(**inputs)
        predictions = outputs.logits.softmax(dim=-1)
        predicted_label = predictions.argmax().item()
        confidence = predictions[0][predicted_label].cpu().item()  # Move back to CPU

        class_names = ["no-tumor", "0.5", "1.0", "1.5"]
        return {
            "radius": class_names[predicted_label],
            "confidence": float(confidence)
        }

    def generate_llm_interpretation(self, tumor_presence, tumor_radius, metadata):
        prompt = f"""<|system|>You are a medical AI assistant. Be concise but thorough.</s>
<|user|>Analyze these results:
Tumor Detection: {json.dumps(tumor_presence)}
Tumor Radius: {json.dumps(tumor_radius)}
Patient: {metadata['age']}y/o {metadata['gender']}, {metadata['smoking_status']}, {metadata['drinking_status']}
Medications: {', '.join(metadata['medications']) if metadata['medications'] else 'None'}
Provide: 1. Key findings 2. Risk assessment 3. Recommendations</s>
<|assistant|>"""

        response = self.llm(
            prompt,
            max_new_tokens=300,  # Reduced for faster response
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1
        )

        return response[0]['generated_text'].split("<|assistant|>")[-1].strip()

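    # The <|system|>/<|user|>/<|assistant|> markers follow the Zephyr-style chat
    # template TinyLlama-1.1B-Chat-v1.0 was trained on; splitting on
    # <|assistant|> strips the echoed prompt from the pipeline output.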
    def analyze_image(self, image):
        try:
            # Add progress updates
            yield "Processing image..."
            processed_image = self.process_image(image)

            yield "Generating patient metadata..."
            metadata = self.generate_synthetic_metadata()

            yield "Analyzing tumor presence..."
            tumor_presence = self.predict_tumor_presence(processed_image)

            yield "Analyzing tumor radius..."
            tumor_radius = self.predict_tumor_radius(processed_image)

            yield "Generating medical interpretation..."
            interpretation = self.generate_llm_interpretation(
                tumor_presence,
                tumor_radius,
                metadata
            )

            # Final results
            result = {
                "metadata": metadata,
                "tumor_presence": tumor_presence,
                "tumor_radius": tumor_radius,
                "interpretation": interpretation
            }

            yield self.format_results(result)

        except Exception as e:
            yield f"Error: {str(e)}"

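    # Because analyze_image is a generator, Gradio streams each yielded status
    # string into the output Textbox before the final formatted report arrives.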
    def format_results(self, results):
        return f"""
Patient Metadata:
{json.dumps(results['metadata'], indent=2)}

Tumor Presence Analysis:
{json.dumps(results['tumor_presence'], indent=2)}

Tumor Radius Analysis:
{json.dumps(results['tumor_radius'], indent=2)}

Medical Interpretation and Recommendations:
{results['interpretation']}
"""

def create_interface():
    system = MedicalImageAnalysisSystem()

    iface = gr.Interface(
        fn=system.analyze_image,
        inputs=[
            gr.Image(type="pil", label="Upload Medical Image")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Medical Image Analysis System",
        description="Upload a medical image for tumor analysis and recommendations.",
        theme=gr.themes.Base(),
        allow_flagging="never"  # gr.Interface takes allow_flagging, not flagging
    )

    return iface

if __name__ == "__main__":
    print("Starting application...")
    iface = create_interface()
    iface.queue()  # Enable queuing for better handling of multiple requests
    iface.launch(debug=True, share=True)
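Once the app is running (note that device_map="auto" additionally requires the accelerate package), the endpoint can also be exercised from Python with gradio_client. A minimal sketch, assuming a local launch on the default port and a recent gradio_client; the URL and sample filename are placeholders:

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")  # or the printed share URL
# One Image input -> one Textbox output, so a single positional argument;
# for generator endpoints, predict() returns the final yielded value.
result = client.predict(handle_file("sample_scan.png"), api_name="/predict")
print(result)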