AliArshad commited on
Commit
ac8852f
·
verified ·
1 Parent(s): 3fcb92e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+ from typing import Tuple, Optional, Dict, Any
6
+ from dataclasses import dataclass
7
+ import random
8
+ from datetime import datetime, timedelta
9
+
10
+ @dataclass
11
+ class PatientMetadata:
12
+ age: int
13
+ smoking_status: str
14
+ family_history: bool
15
+ menopause_status: str
16
+ previous_mammogram: bool
17
+ breast_density: str
18
+ hormone_therapy: bool
19
+
20
+ @dataclass
21
+ class AnalysisResult:
22
+ has_tumor: bool
23
+ tumor_size: str
24
+ confidence: float
25
+ metadata: PatientMetadata
26
+
27
+ class BreastSinogramAnalyzer:
28
+ def __init__(self):
29
+ """Initialize the analyzer with required models."""
30
+ print("Initializing system...")
31
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ print(f"Using device: {self.device}")
33
+
34
+ self._init_vision_models()
35
+ self._init_llm()
36
+ print("Initialization complete!")
37
+
38
+ def _init_vision_models(self) -> None:
39
+ """Initialize vision models for abnormality detection and size measurement."""
40
+ print("Loading detection models...")
41
+ self.tumor_detector = AutoModelForImageClassification.from_pretrained(
42
+ "SIATCN/vit_tumor_classifier"
43
+ ).to(self.device).eval()
44
+ self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
45
+
46
+ self.size_detector = AutoModelForImageClassification.from_pretrained(
47
+ "SIATCN/vit_tumor_radius_detection_finetuned"
48
+ ).to(self.device).eval()
49
+ self.size_processor = AutoImageProcessor.from_pretrained(
50
+ "SIATCN/vit_tumor_radius_detection_finetuned"
51
+ )
52
+
53
+ def _init_llm(self) -> None:
54
+ """Initialize the language model for report generation."""
55
+ print("Loading language model pipeline...")
56
+ self.pipe = pipeline(
57
+ "text-generation",
58
+ model="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
59
+ torch_dtype=torch.float16,
60
+ device_map="auto",
61
+ model_kwargs={
62
+ "load_in_4bit": True,
63
+ "bnb_4bit_compute_dtype": torch.float16,
64
+ }
65
+ )
66
+
67
+ def _generate_synthetic_metadata(self) -> PatientMetadata:
68
+ """Generate realistic patient metadata for breast cancer screening."""
69
+ age = random.randint(40, 75)
70
+ smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
71
+ family_history = random.choice([True, False])
72
+ menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
73
+ previous_mammogram = random.choice([True, False])
74
+ breast_density = random.choice(["A: Almost entirely fatty",
75
+ "B: Scattered fibroglandular",
76
+ "C: Heterogeneously dense",
77
+ "D: Extremely dense"])
78
+ hormone_therapy = random.choice([True, False])
79
+
80
+ return PatientMetadata(
81
+ age=age,
82
+ smoking_status=smoking_status,
83
+ family_history=family_history,
84
+ menopause_status=menopause_status,
85
+ previous_mammogram=previous_mammogram,
86
+ breast_density=breast_density,
87
+ hormone_therapy=hormone_therapy
88
+ )
89
+
90
+ def _process_image(self, image: Image.Image) -> Image.Image:
91
+ """Process input image for model consumption."""
92
+ if image.mode != 'RGB':
93
+ image = image.convert('RGB')
94
+ return image.resize((224, 224))
95
+
96
+ @torch.no_grad()
97
+ def _analyze_image(self, image: Image.Image) -> AnalysisResult:
98
+ """Perform abnormality detection and size measurement."""
99
+ # Generate metadata
100
+ metadata = self._generate_synthetic_metadata()
101
+
102
+ # Detect abnormality
103
+ tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
104
+ tumor_outputs = self.tumor_detector(**tumor_inputs)
105
+ tumor_probs = tumor_outputs.logits.softmax(dim=-1)[0].cpu()
106
+ has_tumor = tumor_probs[1] > tumor_probs[0]
107
+ confidence = float(tumor_probs[1] if has_tumor else tumor_probs[0])
108
+
109
+ # Measure size
110
+ size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
111
+ size_outputs = self.size_detector(**size_inputs)
112
+ size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
113
+ sizes = ["no-tumor", "0.5", "1.0", "1.5"]
114
+ tumor_size = sizes[size_pred.argmax().item()]
115
+
116
+ return AnalysisResult(has_tumor, tumor_size, confidence, metadata)
117
+
118
+ def _generate_medical_report(self, analysis: AnalysisResult) -> str:
119
+ """Generate a simplified medical report."""
120
+ prompt = f"""<|system|>You are a radiologist providing clear and concise medical reports.</s>
121
+ <|user|>Generate a brief medical report for this microwave breast imaging scan:
122
+
123
+ Findings:
124
+ - {'Abnormal' if analysis.has_tumor else 'Normal'} dielectric properties
125
+ - Size: {analysis.tumor_size} cm
126
+ - Confidence: {analysis.confidence:.2%}
127
+ - Patient age: {analysis.metadata.age}
128
+ - Risk factors: {', '.join([
129
+ 'family history' if analysis.metadata.family_history else '',
130
+ analysis.metadata.smoking_status.lower(),
131
+ 'hormone therapy' if analysis.metadata.hormone_therapy else ''
132
+ ]).strip(', ')}
133
+
134
+ Provide:
135
+ 1. One sentence interpreting the findings
136
+ 2. One clear management recommendation</s>
137
+ <|assistant|>"""
138
+
139
+ try:
140
+ response = self.pipe(
141
+ prompt,
142
+ max_new_tokens=128,
143
+ temperature=0.3,
144
+ top_p=0.9,
145
+ repetition_penalty=1.1,
146
+ do_sample=True,
147
+ num_return_sequences=1
148
+ )[0]["generated_text"]
149
+
150
+ # Extract assistant's response
151
+ if "<|assistant|>" in response:
152
+ report = response.split("<|assistant|>")[-1].strip()
153
+ else:
154
+ report = response[len(prompt):].strip()
155
+
156
+ # Simple validation
157
+ if len(report.split()) >= 10:
158
+ return f"""INTERPRETATION AND RECOMMENDATION:
159
+ {report}"""
160
+
161
+ print("Report too short, using fallback")
162
+ return self._generate_fallback_report(analysis)
163
+
164
+ except Exception as e:
165
+ print(f"Error in report generation: {str(e)}")
166
+ return self._generate_fallback_report(analysis)
167
+
168
+ def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
169
+ """Generate a simple fallback report."""
170
+ if analysis.has_tumor:
171
+ return f"""INTERPRETATION AND RECOMMENDATION:
172
+ Microwave imaging reveals abnormal dielectric properties measuring {analysis.tumor_size} cm with {analysis.confidence:.1%} confidence level.
173
+
174
+ {'Immediate conventional imaging and clinical correlation recommended.' if analysis.tumor_size in ['1.0', '1.5'] else 'Follow-up imaging recommended in 6 months.'}"""
175
+ else:
176
+ return f"""INTERPRETATION AND RECOMMENDATION:
177
+ Microwave imaging shows normal dielectric properties with {analysis.confidence:.1%} confidence level.
178
+
179
+ Routine screening recommended per standard protocol."""
180
+
181
+ def analyze(self, image: Image.Image) -> str:
182
+ """Main analysis pipeline."""
183
+ try:
184
+ processed_image = self._process_image(image)
185
+ analysis = self._analyze_image(processed_image)
186
+ report = self._generate_medical_report(analysis)
187
+
188
+ return f"""MICROWAVE IMAGING ANALYSIS:
189
+ • Detection: {'Positive' if analysis.has_tumor else 'Negative'}
190
+ • Size: {analysis.tumor_size} cm
191
+
192
+
193
+ PATIENT INFO:
194
+ • Age: {analysis.metadata.age} years
195
+ • Risk Factors: {', '.join([
196
+ 'family history' if analysis.metadata.family_history else '',
197
+ analysis.metadata.smoking_status.lower(),
198
+ 'hormone therapy' if analysis.metadata.hormone_therapy else '',
199
+ ]).strip(', ')}
200
+
201
+ REPORT:
202
+ {report}"""
203
+ except Exception as e:
204
+ return f"Error during analysis: {str(e)}"
205
+
206
+ def create_interface() -> gr.Interface:
207
+ """Create the Gradio interface."""
208
+ analyzer = BreastSinogramAnalyzer()
209
+
210
+ interface = gr.Interface(
211
+ fn=analyzer.analyze,
212
+ inputs=[
213
+ gr.Image(type="pil", label="Upload Breast Microwave Image")
214
+ ],
215
+ outputs=[
216
+ gr.Textbox(label="Analysis Results", lines=20)
217
+ ],
218
+ title="Breast Cancer Microwave Imaging Analysis System",
219
+ description="Upload a breast microwave image for comprehensive analysis and medical assessment.",
220
+ )
221
+
222
+ return interface
223
+
224
+ if __name__ == "__main__":
225
+ print("Starting application...")
226
+ interface = create_interface()
227
+ interface.launch(debug=True, share=True)