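"""Medical Image Analysis System (Gradio demo).

Classifies tumor presence and estimates tumor radius with two fine-tuned ViT
models, attaches randomly generated patient metadata, and asks a small chat
LLM (TinyLlama) for a narrative interpretation of the combined findings.
"""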
import gradio as gr
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    pipeline
)
from PIL import Image
import torch
import random
import json

class MedicalImageAnalysisSystem:
    def __init__(self):
        print("Initializing system...")
        
        # Check for CUDA availability
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        
        # Load models one at a time with progress messages
        print("Loading tumor classifier...")
        self.tumor_classifier_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.tumor_classifier_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
        
        print("Loading tumor radius detector...")
        self.radius_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.radius_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_radius_detection_finetuned")
        
        print("Loading language model...")
        # Using a smaller model for faster inference
        self.llm = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
            model_kwargs={"low_cpu_mem_usage": True}
        )
        
        # Set models to evaluation mode
        self.tumor_classifier_model.eval()
        self.radius_model.eval()
        
        print("System initialization complete!")

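    # NOTE: the metadata below is randomly generated demo data, not real
    # patient records; it exists only to exercise the LLM prompt.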
    def generate_synthetic_metadata(self):
        return {
            "age": random.randint(25, 85),
            "gender": random.choice(["Male", "Female"]),
            "smoking_status": random.choice(["Never Smoker", "Former Smoker", "Current Smoker"]),
            "drinking_status": random.choice(["Non-drinker", "Social Drinker", "Regular Drinker"]),
            "medications": random.sample([
                "Lisinopril", "Metformin", "Levothyroxine", "Amlodipine",
                "Metoprolol", "Omeprazole", "Simvastatin", "Losartan"
            ], random.randint(0, 3))
        }

    def process_image(self, image):
        # Normalize input to a 224x224 RGB PIL image (the standard ViT input
        # size; the AutoImageProcessors also resize internally, so this mainly
        # guards against odd inputs)
        if isinstance(image, str):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    @torch.no_grad()
    def predict_tumor_presence(self, processed_image):
        inputs = self.tumor_classifier_processor(processed_image, return_tensors="pt")
        # Cast floating inputs to the model's dtype (float16 on GPU) to avoid
        # input/weight dtype mismatches when the model was loaded in half precision
        inputs = {k: v.to(self.device, dtype=self.tumor_classifier_model.dtype)
                  if v.is_floating_point() else v.to(self.device)
                  for k, v in inputs.items()}
        outputs = self.tumor_classifier_model(**inputs)
        predictions = torch.softmax(outputs.logits, dim=-1)
        probs = predictions[0].cpu().tolist()
        # Return just the predicted class (assumes index 1 is the tumor class;
        # check self.tumor_classifier_model.config.id2label to confirm)
        return "tumor" if probs[1] > probs[0] else "non-tumor"

    @torch.no_grad()
    def predict_tumor_radius(self, processed_image):
        inputs = self.radius_processor(processed_image, return_tensors="pt")
        # Same dtype handling as predict_tumor_presence
        inputs = {k: v.to(self.device, dtype=self.radius_model.dtype)
                  if v.is_floating_point() else v.to(self.device)
                  for k, v in inputs.items()}
        outputs = self.radius_model(**inputs)
        predictions = outputs.logits.softmax(dim=-1)
        predicted_label = predictions.argmax().item()
        # Label order should match the fine-tuned head; verify against
        # self.radius_model.config.id2label if the sizes look wrong
        class_names = ["no-tumor", "0.5", "1.0", "1.5"]
        # Return just the radius without confidence
        return class_names[predicted_label]

    def generate_llm_interpretation(self, tumor_presence, tumor_radius, metadata):
        # TinyLlama-Chat uses the Zephyr prompt format (<|system|>/<|user|>/<|assistant|>);
        # keep the prompt flush-left so indentation does not leak into the model input
        prompt = f"""<|system|>
You are a medical AI assistant. Provide a clear and concise medical interpretation.</s>
<|user|>
Analyze the following medical findings:

Image Analysis:
- Tumor Detection: {tumor_presence}
- Tumor Size: {tumor_radius} cm

Patient Profile:
- Age: {metadata['age']} years
- Gender: {metadata['gender']}
- Smoking: {metadata['smoking_status']}
- Alcohol: {metadata['drinking_status']}
- Current Medications: {', '.join(metadata['medications']) if metadata['medications'] else 'None'}

Provide a brief:
1. Key findings
2. Clinical recommendations
3. Follow-up plan</s>
<|assistant|>
"""

        response = self.llm(
            prompt,
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1
        )

        # The pipeline returns prompt + completion; keep only the assistant's reply
        return response[0]['generated_text'].split("<|assistant|>")[-1].strip()

    def analyze_image(self, image):
        try:
            if image is None:
                return "Please upload an image first."
            # Process image and generate metadata
            processed_image = self.process_image(image)
            metadata = self.generate_synthetic_metadata()
            
            # Get predictions
            tumor_presence = self.predict_tumor_presence(processed_image)
            tumor_radius = self.predict_tumor_radius(processed_image)
            
            # Generate interpretation
            interpretation = self.generate_llm_interpretation(
                tumor_presence,
                tumor_radius,
                metadata
            )
            
            # Format results
            results = {
                "metadata": metadata,
                "tumor_presence": tumor_presence,
                "tumor_radius": tumor_radius,
                "interpretation": interpretation
            }
            
            return self.format_results(results)
            
        except Exception as e:
            return f"Error: {str(e)}"

    def format_results(self, results):
        # Flush-left report string so the Gradio textbox shows no stray indentation
        return f"""Patient Information:
{json.dumps(results['metadata'], indent=2)}

Image Analysis Results:
- Tumor Detection: {results['tumor_presence']}
- Tumor Size: {results['tumor_radius']} cm

Medical Assessment:
{results['interpretation']}
"""

def create_interface():
    system = MedicalImageAnalysisSystem()
    
    iface = gr.Interface(
        fn=system.analyze_image,
        inputs=[
            gr.Image(type="pil", label="Upload Medical Image")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Medical Image Analysis System",
        description="Upload a medical image for tumor analysis and recommendations.",
        theme=gr.themes.Base()
    )
    
    return iface

if __name__ == "__main__":
    print("Starting application...")
    iface = create_interface()
    iface.queue()
    iface.launch(debug=True, share=True)
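# Usage sketch (assumes this file is saved as app.py): `python app.py` serves
# the UI at the printed local URL and, with share=True, a temporary public link.
# The first run downloads all three models from the Hugging Face Hub.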