Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from PIL import Image | |
import torch | |
from typing import Tuple, Optional, Dict, Any, List | |
from dataclasses import dataclass | |
import random | |
from datetime import datetime, timedelta | |
import os | |
from qwen_agent.agents import Assistant | |
from qwen_agent.gui.web_ui import WebUI | |
@dataclass
class PatientMetadata:
    """Patient context attached to an analysis report.

    In this file the values are produced randomly by
    ``BreastCancerAgent._generate_synthetic_metadata``; they are not read from
    any patient record. ``@dataclass`` generates the keyword/positional
    ``__init__`` that callers rely on.
    """

    age: int  # age in years
    smoking_status: str  # e.g. "Never Smoker", "Former Smoker", "Current Smoker"
    family_history: bool  # whether a family history is recorded
    menopause_status: str  # "Pre-menopausal" or "Post-menopausal"
    previous_mammogram: bool  # whether a prior mammogram is recorded
    breast_density: str  # density category label, e.g. "C: Heterogeneously dense"
    hormone_therapy: bool  # whether hormone therapy is recorded
@dataclass
class AnalysisResult:
    """Outcome of analyzing a single image.

    ``@dataclass`` generates the positional ``__init__`` used as
    ``AnalysisResult(has_tumor, tumor_size, confidence, metadata)``.
    """

    has_tumor: bool  # whether the tumor classifier's prediction was positive
    tumor_size: str  # size-bucket label: "no-tumor", "0.5", "1.0" or "1.5"
    confidence: float  # probability (0-1) of the predicted detection class
    # Quoted forward reference: not evaluated at class-creation time, so this
    # class does not depend on PatientMetadata being defined first.
    metadata: "PatientMetadata"
class BreastCancerAgent(Assistant):
    """Qwen-Agent assistant that screens breast microwave images with two ViT classifiers.

    One Hugging Face image classifier predicts tumor / no tumor. A second one
    predicts a coarse tumor-size bucket. ``run`` combines both predictions
    into a plain-text report.
    """

    # Size-bucket labels indexed by the size classifier's output class.
    # NOTE(review): assumes the model's class order matches this tuple;
    # verify against self.size_detector.config.id2label.
    _SIZE_LABELS = ("no-tumor", "0.5", "1.0", "1.5")

    def __init__(self):
        """Configure the LLM backend from environment variables and load the vision models."""
        super().__init__(
            llm={
                'model': os.environ.get("MODELNAME", "qwen-vl-chat"),
                'generate_cfg': {
                    'max_input_tokens': 32768,
                    'max_retries': 10,
                    # Sampling parameters can be overridden via env vars T, R, K, P.
                    'temperature': float(os.environ.get("T", 0.001)),
                    'repetition_penalty': float(os.environ.get("R", 1.0)),
                    "top_k": int(os.environ.get("K", 20)),
                    "top_p": float(os.environ.get("P", 0.8)),
                }
            },
            name='Breast Cancer Analyzer',
            description='Medical imaging analysis system specializing in breast cancer detection and reporting.',
            system_message='You are an expert medical imaging system analyzing breast cancer scans. Provide clear, accurate, and professional analysis.'
        )
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._init_vision_models()
        print("Initialization complete!")

    def _init_vision_models(self) -> None:
        """Load the tumor-detection and tumor-size classifiers (in eval mode) plus their processors."""
        print("Loading detection models...")
        self.tumor_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        ).to(self.device).eval()
        self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
        self.size_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        ).to(self.device).eval()
        self.size_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

    def _process_image(self, image: Image.Image) -> Image.Image:
        """Return an RGB copy of *image* resized to 224x224."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    def _analyze_image(self, image: Image.Image) -> AnalysisResult:
        """Run tumor detection and size classification on a preprocessed image.

        The attached patient metadata is randomly generated, not derived from
        the image.
        """
        metadata = self._generate_synthetic_metadata()
        # Pure inference: no_grad avoids building autograd graphs and keeping activations.
        with torch.no_grad():
            tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
            tumor_probs = self.tumor_detector(**tumor_inputs).logits.softmax(dim=-1)[0].cpu()
            size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
            size_probs = self.size_detector(**size_inputs).logits.softmax(dim=-1)[0].cpu()
        # NOTE(review): assumes class index 1 means "tumor" and 0 means "no tumor";
        # confirm against self.tumor_detector.config.id2label.
        has_tumor = bool(tumor_probs[1] > tumor_probs[0])  # plain bool, not a 0-dim tensor
        confidence = float(tumor_probs[1] if has_tumor else tumor_probs[0])
        tumor_size = self._SIZE_LABELS[int(size_probs.argmax())]
        return AnalysisResult(has_tumor, tumor_size, confidence, metadata)

    def _generate_synthetic_metadata(self) -> PatientMetadata:
        """Generate random placeholder patient metadata (not taken from any patient record)."""
        age = random.randint(40, 75)
        smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
        family_history = random.choice([True, False])
        menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
        previous_mammogram = random.choice([True, False])
        breast_density = random.choice(["A: Almost entirely fatty",
                                        "B: Scattered fibroglandular",
                                        "C: Heterogeneously dense",
                                        "D: Extremely dense"])
        hormone_therapy = random.choice([True, False])
        return PatientMetadata(
            age=age,
            smoking_status=smoking_status,
            family_history=family_history,
            menopause_status=menopause_status,
            previous_mammogram=previous_mammogram,
            breast_density=breast_density,
            hormone_therapy=hormone_therapy
        )

    @staticmethod
    def _format_report(analysis: AnalysisResult) -> str:
        """Render *analysis* as the plain-text report returned by ``run``."""
        meta = analysis.metadata
        risk_factors = ', '.join(
            factor
            for factor in (
                'family history' if meta.family_history else '',
                meta.smoking_status.lower(),
                'hormone therapy' if meta.hormone_therapy else '',
            )
            if factor
        )
        # The size model has its own "no-tumor" class; don't print it as "no-tumor cm".
        size_text = "N/A" if analysis.tumor_size == "no-tumor" else f"{analysis.tumor_size} cm"
        if analysis.has_tumor:
            finding = 'Abnormal scan showing potential mass.'
            if analysis.tumor_size in ('1.0', '1.5'):
                recommendation = 'Immediate follow-up imaging recommended.'
            else:
                recommendation = 'Follow-up imaging in 6 months recommended.'
        else:
            finding = 'Normal scan with no significant findings.'
            recommendation = 'Continue routine screening per protocol.'
        # Patient info is randomly generated, so label it to avoid presenting it as real data.
        return f"""MICROWAVE IMAGING ANALYSIS:
• Detection: {'Positive' if analysis.has_tumor else 'Negative'}
• Size: {size_text}
PATIENT INFO (SYNTHETIC DEMO DATA):
• Age: {meta.age} years
• Risk Factors: {risk_factors}
REPORT:
{finding}
Confidence level: {analysis.confidence:.1%}
RECOMMENDATION:
{recommendation}"""

    def run(self, image_path: str) -> str:
        """Analyze the image at *image_path* and return a text report.

        Any exception is converted into an "Error during analysis: ..." string
        rather than raised.
        """
        # NOTE(review): this overrides qwen_agent's Assistant.run, which (as used by
        # WebUI) appears to receive a list of chat messages and yield responses.
        # This signature takes a path and returns a str; confirm, and consider
        # overriding _run instead so the web UI can call it.
        try:
            # Close the file handle; _process_image returns an independent copy.
            with Image.open(image_path) as image:
                processed_image = self._process_image(image)
            analysis = self._analyze_image(processed_image)
            return self._format_report(analysis)
        except Exception as e:
            return f"Error during analysis: {str(e)}"
def run_interface():
    """Instantiate the analyzer agent and serve it through Qwen-Agent's web UI."""
    analyzer = BreastCancerAgent()

    # Canned prompts shown as clickable suggestions in the chat box.
    suggestion_texts = (
        'Can you analyze this mammogram?',
        'What should I look for in the results?',
        'How reliable is the detection?',
    )
    ui_settings = {
        'user.name': 'Medical Staff',
        'input.placeholder': 'Upload a breast microwave image for analysis...',
        'prompt.suggestions': [{'text': text} for text in suggestion_texts],
    }

    web_app = WebUI(analyzer, chatbot_config=ui_settings)
    web_app.run(share=True, concurrency_limit=80)


if __name__ == "__main__":
    # Script entry point: announce startup, then block serving the UI.
    print("Starting application...")
    run_interface()