# Hugging Face Space status at time of capture: Runtime error
import gradio as gr | |
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer | |
from PIL import Image | |
import torch | |
from typing import Tuple, Optional, Dict, Any | |
from dataclasses import dataclass | |
import random | |
@dataclass
class PatientMetadata:
    """Patient context attached to an analysis result.

    ``@dataclass`` is required: instances are built with keyword arguments
    (``PatientMetadata(age=..., ...)``), which a plain annotated class rejects.
    """

    age: int  # years
    smoking_status: str  # e.g. "Never Smoker", "Former Smoker", "Current Smoker"
    family_history: bool  # family history of breast cancer
    menopause_status: str  # e.g. "Pre-menopausal" / "Post-menopausal"
    previous_mammogram: bool
    breast_density: str  # density category text, e.g. "C: Heterogeneously dense"
    hormone_therapy: bool  # currently on hormone therapy
@dataclass
class AnalysisResult:
    """Outcome of analysing one image.

    ``@dataclass`` is required: instances are built positionally
    (``AnalysisResult(has_tumor, tumor_size, metadata)``), which a plain
    annotated class rejects.
    """

    has_tumor: bool
    tumor_size: str  # size label such as "0.5"/"1.0"/"1.5" (cm) or "no-tumor"
    metadata: "PatientMetadata"
class BreastSinogramAnalyzer:
    """Image-analysis pipeline: tumour detection, size classification, text report.

    NOTE(review): the patient information attached to every result is randomly
    generated by ``_generate_synthetic_metadata``; it is not derived from the
    uploaded image or from any patient record.
    """

    # Label for each output index of the size classifier, in index order.
    # Assumes the checkpoint's classes are ordered this way -- TODO confirm
    # against the model's ``config.id2label``.
    SIZE_LABELS = ("no-tumor", "0.5", "1.0", "1.5")

    def __init__(self, llm_model_name: str = "Qwen/QwQ-32B-Preview"):
        """Load the vision models and the report LLM.

        Args:
            llm_model_name: Hugging Face id of the causal LM used for reports.
                Defaults to the original ``Qwen/QwQ-32B-Preview``. The model's
                tokenizer must provide a chat template (``apply_chat_template``
                is used for prompting).
        """
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._init_vision_models()
        self._init_llm(llm_model_name)
        print("Initialization complete!")

    def _init_vision_models(self) -> None:
        """Load the tumour classifier, the size classifier and their processors.

        Both classifiers are moved to ``self.device`` and put in eval mode.
        """
        print("Loading detection models...")
        self.tumor_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        ).to(self.device).eval()
        self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
        self.size_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        ).to(self.device).eval()
        self.size_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

    def _init_llm(self, model_name: str = "Qwen/QwQ-32B-Preview") -> None:
        """Load the causal LM and tokenizer used by ``_generate_medical_report``.

        Placement is delegated to ``device_map="auto"`` and the dtype to
        ``torch_dtype="auto"``.
        """
        print(f"Loading language model: {model_name}...")
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def _generate_synthetic_metadata(self) -> "PatientMetadata":
        """Return randomly generated patient metadata.

        These values are placeholders drawn with ``random``: age is uniform in
        40-75 inclusive, and the other fields are uniform choices. Menopause
        status is derived from age (> 50 means post-menopausal).
        """
        age = random.randint(40, 75)
        smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
        family_history = random.choice([True, False])
        menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
        previous_mammogram = random.choice([True, False])
        breast_density = random.choice([
            "A: Almost entirely fatty",
            "B: Scattered fibroglandular",
            "C: Heterogeneously dense",
            "D: Extremely dense",
        ])
        hormone_therapy = random.choice([True, False])
        return PatientMetadata(
            age=age,
            smoking_status=smoking_status,
            family_history=family_history,
            menopause_status=menopause_status,
            previous_mammogram=previous_mammogram,
            breast_density=breast_density,
            hormone_therapy=hormone_therapy,
        )

    def _process_image(self, image: "Image.Image") -> "Image.Image":
        """Convert *image* to RGB if needed and resize it to 224x224.

        The aspect ratio is not preserved. The HF image processors may resize
        again -- TODO confirm whether this step is redundant.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    def _analyze_image(self, image: "Image.Image") -> "AnalysisResult":
        """Classify *image* for an abnormality and, if one is found, its size.

        Returns:
            An ``AnalysisResult`` carrying freshly generated synthetic metadata.
            ``has_tumor`` is a plain ``bool``. ``tumor_size`` is one of
            ``SIZE_LABELS``: it is ``"no-tumor"`` exactly when ``has_tumor`` is
            False, and otherwise one of the actual size labels.
        """
        metadata = self._generate_synthetic_metadata()
        # Inference only: disable autograd so activations are not kept for backprop.
        with torch.no_grad():
            tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
            tumor_logits = self.tumor_detector(**tumor_inputs).logits
            tumor_probs = tumor_logits.softmax(dim=-1)[0].cpu()
            # Assumes index 1 = "tumour" and index 0 = "no tumour" -- TODO confirm
            # against the checkpoint's id2label.
            has_tumor = bool(tumor_probs[1] > tumor_probs[0])

            tumor_size = self.SIZE_LABELS[0]
            if has_tumor:
                size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
                size_logits = self.size_detector(**size_inputs).logits
                size_probs = size_logits.softmax(dim=-1)[0].cpu()
                # The two classifiers are independent, so the size model can still
                # vote "no-tumor" here. Take the most likely *actual* size so the
                # report never reads "measuring no-tumor cm".
                size_index = int(size_probs[1:].argmax().item()) + 1
                tumor_size = self.SIZE_LABELS[size_index]
        return AnalysisResult(has_tumor, tumor_size, metadata)

    @staticmethod
    def _format_risk_factors(metadata: "PatientMetadata") -> str:
        """Join the applicable risk-factor phrases with ", ".

        The smoking status (lower-cased) is always included. The family-history
        and hormone-therapy phrases are included only when their flags are set.
        """
        factors = []
        if metadata.family_history:
            factors.append("family history of breast cancer")
        factors.append(metadata.smoking_status.lower())
        if metadata.hormone_therapy:
            factors.append("currently on hormone therapy")
        return ", ".join(factors)

    def _generate_medical_report(self, analysis: "AnalysisResult") -> str:
        """Ask the LLM for a short report, falling back to a fixed template.

        The LLM text is used only if it has at least 10 whitespace-separated
        words. The template from ``_generate_fallback_report`` is returned
        instead when the text is shorter, or when any step of prompting or
        generation raises (the error is printed).
        """
        try:
            finding = 'Abnormal area detected' if analysis.has_tumor else 'No abnormalities detected'
            size_line = f'- Size of abnormal area: {analysis.tumor_size} cm' if analysis.has_tumor else ''
            risk_factors = self._format_risk_factors(analysis.metadata)
            messages = [
                {
                    "role": "system",
                    "content": "You are a radiologist providing clear and straightforward medical reports. Focus on clarity and actionable recommendations."
                },
                {
                    "role": "user",
                    "content": f"""Generate a clear medical report for this breast imaging scan:
Scan Results:
- Finding: {finding}
{size_line}
Patient Information:
- Age: {analysis.metadata.age} years
- Risk factors: {risk_factors}
Please provide:
1. A clear interpretation of the findings
2. A specific recommendation for next steps"""
                },
            ]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            # NOTE(review): with a reasoning-style model such as QwQ, 128 new tokens
            # may cut the answer short -- consider raising max_new_tokens.
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=128,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
            )
            # Keep only the newly generated tokens (drop the echoed prompt).
            new_tokens = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            # Best-effort: any LLM failure degrades to the template report.
            print(f"Error in report generation: {str(e)}")
            return self._generate_fallback_report(analysis)
        if len(response.split()) >= 10:
            return f"""FINDINGS AND RECOMMENDATIONS:
{response}"""
        return self._generate_fallback_report(analysis)

    def _generate_fallback_report(self, analysis: "AnalysisResult") -> str:
        """Return a fixed-template report used when LLM output is unavailable.

        A detected size of "1.0" or "1.5" triggers an immediate follow-up
        recommendation. Any other detected size gets a 6-month follow-up.
        """
        if not analysis.has_tumor:
            return """FINDINGS AND RECOMMENDATIONS:
Finding: No abnormal areas were detected during this scan.
Recommendation: Continue with routine screening as per standard guidelines."""
        if analysis.tumor_size in ("1.0", "1.5"):
            recommendation = "An immediate follow-up with conventional mammogram and ultrasound is required."
        else:
            recommendation = "A follow-up scan is recommended in 6 months."
        return f"""FINDINGS AND RECOMMENDATIONS:
Finding: An abnormal area measuring {analysis.tumor_size} cm was detected during the scan.
Recommendation: {recommendation}"""

    def analyze(self, image: "Image.Image") -> str:
        """Run the full pipeline on one uploaded image and return display text.

        Returns:
            A plain-text summary containing the scan result, the synthetic
            patient information and the report. Returns an upload prompt if
            *image* is None, or an "Error during analysis: ..." message if any
            step raises.
        """
        if image is None:
            # Gradio's Image input can deliver None when nothing was uploaded.
            return "Please upload an image to analyze."
        try:
            processed_image = self._process_image(image)
            analysis = self._analyze_image(processed_image)
            report = self._generate_medical_report(analysis)
            status = '⚠️ Abnormal area detected' if analysis.has_tumor else '✓ No abnormalities detected'
            size_line = f'Size of abnormal area: {analysis.tumor_size} cm' if analysis.has_tumor else ''
            return f"""SCAN RESULTS:
{status}
{size_line}
PATIENT INFORMATION:
• Age: {analysis.metadata.age} years
• Risk Factors: {self._format_risk_factors(analysis.metadata)}
{report}"""
        except Exception as e:
            print(f"Error during analysis: {str(e)}")
            return f"Error during analysis: {str(e)}"
def create_interface() -> gr.Interface:
    """Build the Gradio UI, wired to a newly constructed analyzer.

    Constructing the analyzer loads every model, so this call is slow.
    """
    analyzer = BreastSinogramAnalyzer()
    image_input = gr.Image(type="pil", label="Upload Breast Image for Analysis")
    results_output = gr.Textbox(label="Analysis Results", lines=20)
    return gr.Interface(
        fn=analyzer.analyze,
        inputs=[image_input],
        outputs=[results_output],
        title="Breast Imaging Analysis System",
        description="Upload a breast image for analysis and medical assessment.",
    )
if __name__ == "__main__":
    # Script entry point: build the UI (this loads all models), then serve it.
    print("Starting application...")
    app = create_interface()
    app.launch(debug=True, share=True)