# BreastCare / app.py
# Author: AliArshad — last commit d7a8a79 "Update app.py" (verified), 6.94 kB
import json
import os
import random
import time
import traceback

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    pipeline,
)
class MedicalImageAnalysisSystem:
    """Tumor classification and radius estimation on an image, plus an LLM-written summary.

    Three models are loaded at construction time:

    * a ViT image classifier that labels the image "tumor" / "non-tumor",
    * a ViT image classifier that buckets the tumor radius into one of
      ``RADIUS_CLASS_NAMES``, and
    * a TinyLlama chat model that turns the findings into prose.

    NOTE: the patient metadata passed to the LLM is randomly generated by
    ``generate_synthetic_metadata``; it is not derived from the image.
    """

    TUMOR_CLASSIFIER_ID = "SIATCN/vit_tumor_classifier"
    RADIUS_MODEL_ID = "SIATCN/vit_tumor_radius_detection_finetuned"
    LLM_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Radius label meaning "no tumor found" (index 0 of the radius model's output).
    NO_TUMOR_RADIUS = "no-tumor"
    # Radius labels indexed by logit position; values are radii in cm.
    # Assumed to match the radius model's label order — TODO confirm against config.id2label.
    RADIUS_CLASS_NAMES = (NO_TUMOR_RADIUS, "0.5", "1.0", "1.5")

    def __init__(self):
        """Load all models onto CUDA when available, otherwise the CPU."""
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Half precision on GPU only; CPU inference stays in float32.
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        print("Loading tumor classifier...")
        self.tumor_classifier_model = AutoModelForImageClassification.from_pretrained(
            self.TUMOR_CLASSIFIER_ID,
            torch_dtype=dtype,
        ).to(self.device)
        self.tumor_classifier_processor = AutoImageProcessor.from_pretrained(self.TUMOR_CLASSIFIER_ID)

        print("Loading tumor radius detector...")
        self.radius_model = AutoModelForImageClassification.from_pretrained(
            self.RADIUS_MODEL_ID,
            torch_dtype=dtype,
        ).to(self.device)
        self.radius_processor = AutoImageProcessor.from_pretrained(self.RADIUS_MODEL_ID)

        print("Loading language model...")
        # A small chat model keeps generation latency manageable.
        self.llm = pipeline(
            "text-generation",
            model=self.LLM_ID,
            torch_dtype=dtype,
            device_map="auto",
            model_kwargs={"low_cpu_mem_usage": True},
        )

        # Inference only: disable dropout etc.
        self.tumor_classifier_model.eval()
        self.radius_model.eval()
        print("System initialization complete!")

    def generate_synthetic_metadata(self):
        """Return a random, fake patient profile (demo data only).

        Returns:
            dict with keys ``age`` (int, 25-85), ``gender``, ``smoking_status``,
            ``drinking_status`` and ``medications`` (list of 0-3 drug names).
        """
        return {
            "age": random.randint(25, 85),
            "gender": random.choice(["Male", "Female"]),
            "smoking_status": random.choice(["Never Smoker", "Former Smoker", "Current Smoker"]),
            "drinking_status": random.choice(["Non-drinker", "Social Drinker", "Regular Drinker"]),
            "medications": random.sample([
                "Lisinopril", "Metformin", "Levothyroxine", "Amlodipine",
                "Metoprolol", "Omeprazole", "Simvastatin", "Losartan",
            ], random.randint(0, 3)),
        }

    def process_image(self, image):
        """Return *image* as a 224x224 RGB PIL image.

        Args:
            image: a PIL image, or a filesystem path (``str`` / ``os.PathLike``) to one.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() decodes the pixels into a new image, so the file handle can be
            # closed by the context manager instead of being leaked.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif image.mode != "RGB":
            image = image.convert("RGB")
        return image.resize((224, 224))

    def _prepare_inputs(self, inputs, model):
        """Move processor outputs to ``self.device`` and cast floats to ``model.dtype``.

        The models may be loaded in float16 (CUDA path in ``__init__``), while image
        processors typically emit float32 ``pixel_values``. Feeding float32 tensors to
        half-precision weights fails with a dtype mismatch, so floating-point tensors
        are cast to the model's dtype when the model exposes one.
        """
        model_dtype = getattr(model, "dtype", None)
        prepared = {}
        for name, tensor in inputs.items():
            tensor = tensor.to(self.device)
            if model_dtype is not None and tensor.is_floating_point():
                tensor = tensor.to(model_dtype)
            prepared[name] = tensor
        return prepared

    @torch.no_grad()
    def predict_tumor_presence(self, processed_image):
        """Classify the image as ``"tumor"`` or ``"non-tumor"``.

        Assumes logit index 1 is the tumor class and index 0 the non-tumor class
        (TODO confirm against the model's ``config.id2label``).
        """
        inputs = self.tumor_classifier_processor(processed_image, return_tensors="pt")
        inputs = self._prepare_inputs(inputs, self.tumor_classifier_model)
        outputs = self.tumor_classifier_model(**inputs)
        # Softmax is monotonic, so comparing logits gives the same decision; doing it in
        # float32 avoids float16 rounding turning a narrow margin into a tie.
        scores = outputs.logits[0].float().cpu().tolist()
        return "tumor" if scores[1] > scores[0] else "non-tumor"

    @torch.no_grad()
    def predict_tumor_radius(self, processed_image):
        """Return the predicted radius label, one of ``RADIUS_CLASS_NAMES``."""
        inputs = self.radius_processor(processed_image, return_tensors="pt")
        inputs = self._prepare_inputs(inputs, self.radius_model)
        outputs = self.radius_model(**inputs)
        # argmax over the class axis of the single image in the batch.
        predicted_index = int(outputs.logits[0].argmax(dim=-1))
        return self.RADIUS_CLASS_NAMES[predicted_index]

    @classmethod
    def _format_radius(cls, tumor_radius):
        """Render a radius label for display: ``"N/A"`` for no tumor, else ``"<r> cm"``."""
        if tumor_radius == cls.NO_TUMOR_RADIUS:
            return "N/A"
        return f"{tumor_radius} cm"

    def generate_llm_interpretation(self, tumor_presence, tumor_radius, metadata):
        """Ask the chat LLM for a short interpretation of the findings.

        Args:
            tumor_presence: label returned by ``predict_tumor_presence``.
            tumor_radius: label returned by ``predict_tumor_radius``.
            metadata: dict returned by ``generate_synthetic_metadata``.

        Returns:
            Only the model's reply text (the prompt is not echoed back), stripped.
        """
        medications = ", ".join(metadata["medications"]) or "None"
        user_message = "\n".join([
            "Analyze the following medical findings:",
            "Image Analysis:",
            f"- Tumor Detection: {tumor_presence}",
            f"- Tumor Size: {self._format_radius(tumor_radius)}",
            "Patient Profile:",
            f"- Age: {metadata['age']} years",
            f"- Gender: {metadata['gender']}",
            f"- Smoking: {metadata['smoking_status']}",
            f"- Alcohol: {metadata['drinking_status']}",
            f"- Current Medications: {medications}",
            "Provide a brief:",
            "1. Key findings",
            "2. Clinical recommendations",
            "3. Follow-up plan",
        ])
        messages = [
            {
                "role": "system",
                "content": "You are a medical AI assistant. Provide a clear and concise medical interpretation.",
            },
            {"role": "user", "content": user_message},
        ]
        # Let the tokenizer produce the model's own chat format (tags + newlines)
        # instead of hand-writing <|system|>/<|user|>/<|assistant|> markers.
        prompt = self.llm.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        response = self.llm(
            prompt,
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1,
            # Return only the completion, so no prompt-splitting is needed.
            return_full_text=False,
        )
        return response[0]["generated_text"].strip()

    def analyze_image(self, image):
        """Run the full pipeline on one image and return a plain-text report.

        This is the Gradio callback, so failures are returned as an ``"Error: ..."``
        string (shown in the UI) rather than raised; the traceback is printed to the
        server log.
        """
        if image is None:
            # e.g. the form was submitted without an uploaded image.
            return "Error: No image provided. Please upload an image."
        try:
            processed_image = self.process_image(image)
            metadata = self.generate_synthetic_metadata()
            # NOTE: the two classifiers run independently and may disagree
            # (e.g. "tumor" with radius "no-tumor"); both results are reported as-is.
            tumor_presence = self.predict_tumor_presence(processed_image)
            tumor_radius = self.predict_tumor_radius(processed_image)
            interpretation = self.generate_llm_interpretation(
                tumor_presence,
                tumor_radius,
                metadata,
            )
            results = {
                "metadata": metadata,
                "tumor_presence": tumor_presence,
                "tumor_radius": tumor_radius,
                "interpretation": interpretation,
            }
            return self.format_results(results)
        except Exception as e:
            # UI boundary: keep the app responsive but don't lose the traceback.
            traceback.print_exc()
            return f"Error: {str(e)}"

    def format_results(self, results):
        """Render the analysis dict as the text report shown in the UI.

        Args:
            results: dict with keys ``metadata``, ``tumor_presence``,
                ``tumor_radius`` and ``interpretation`` (as built by ``analyze_image``).
        """
        lines = [
            # The profile is random demo data; say so rather than present it as real.
            "Patient Information (synthetic demo data, not derived from the image):",
            json.dumps(results["metadata"], indent=2),
            "",
            "Image Analysis Results:",
            f"- Tumor Detection: {results['tumor_presence']}",
            f"- Tumor Size: {self._format_radius(results['tumor_radius'])}",
            "",
            "Medical Assessment:",
            str(results["interpretation"]),
        ]
        return "\n" + "\n".join(lines) + "\n"
def create_interface():
    """Build the Gradio front end around a freshly initialised analysis system.

    Constructing the system loads every model, so this call is slow and should be
    made once at startup.

    Returns:
        gr.Interface: a single-image-in, text-report-out interface whose callback
        is ``MedicalImageAnalysisSystem.analyze_image``.
    """
    analyzer = MedicalImageAnalysisSystem()
    image_input = gr.Image(type="pil", label="Upload Medical Image")
    report_output = gr.Textbox(label="Analysis Results", lines=20)
    return gr.Interface(
        fn=analyzer.analyze_image,
        inputs=[image_input],
        outputs=[report_output],
        title="Medical Image Analysis System",
        description="Upload a medical image for tumor analysis and recommendations.",
        theme=gr.themes.Base(),
    )
if __name__ == "__main__":
    # Script entry point: build the UI (loads all models), enable Gradio's
    # request queue, then start the server.
    print("Starting application...")
    interface = create_interface()
    interface.queue()
    # NOTE(review): share=True asks Gradio for a public share link when run
    # locally — confirm that is intended for a medical-imaging app.
    interface.launch(debug=True, share=True)