import random
from dataclasses import dataclass
from typing import Tuple, Optional, Dict, Any

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer


@dataclass
class PatientMetadata:
    """Patient-level screening attributes attached to one scan.

    In this module the values come from
    ``BreastSinogramAnalyzer._generate_synthetic_metadata``, i.e. they are
    randomly generated, not read from a real patient record.
    """

    age: int
    smoking_status: str
    family_history: bool
    menopause_status: str
    previous_mammogram: bool
    breast_density: str
    hormone_therapy: bool


@dataclass
class AnalysisResult:
    """Vision-model findings for one image together with its patient metadata."""

    # True when the abnormality classifier scores class 1 above class 0.
    has_tumor: bool
    # One of BreastSinogramAnalyzer.SIZE_LABELS; numeric labels are rendered as centimetres.
    tumor_size: str
    metadata: PatientMetadata


class BreastSinogramAnalyzer:
    """Analyze a breast image and produce a human-readable report.

    Two ViT image classifiers decide whether an abnormal area is present and,
    if so, which size class it falls into. A causal language model then writes
    the findings/recommendations text. A template-based fallback report is used
    when generation raises or returns fewer than 10 words.
    """

    # Size-model labels indexed by logit position.
    # NOTE(review): assumes the checkpoint's id2label order matches this tuple — confirm.
    SIZE_LABELS = ("no-tumor", "0.5", "1.0", "1.5")
    # Size classes for which the fallback report asks for an immediate follow-up.
    URGENT_SIZES = ("1.0", "1.5")

    def __init__(self):
        """Initialize the analyzer with required models."""
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._init_vision_models()
        self._init_llm()
        print("Initialization complete!")

    def _init_vision_models(self) -> None:
        """Load the abnormality-detection and size-classification models and their processors."""
        print("Loading detection models...")
        self.tumor_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        ).to(self.device).eval()
        self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")

        self.size_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        ).to(self.device).eval()
        self.size_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

    def _init_llm(self) -> None:
        """Load the Qwen language model and tokenizer used for report generation."""
        print("Loading Qwen language model...")
        self.model_name = "Qwen/QwQ-32B-Preview"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def _generate_synthetic_metadata(self) -> PatientMetadata:
        """Return randomly generated patient metadata.

        NOTE(review): these values are fabricated on every call and are then
        shown to the user (and to the LLM) as "patient information". Replace
        with real inputs before any clinical use.
        """
        age = random.randint(40, 75)
        smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
        family_history = random.choice([True, False])
        # Simplified rule: anyone over 50 is treated as post-menopausal.
        menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
        previous_mammogram = random.choice([True, False])
        breast_density = random.choice([
            "A: Almost entirely fatty",
            "B: Scattered fibroglandular",
            "C: Heterogeneously dense",
            "D: Extremely dense",
        ])
        hormone_therapy = random.choice([True, False])

        return PatientMetadata(
            age=age,
            smoking_status=smoking_status,
            family_history=family_history,
            menopause_status=menopause_status,
            previous_mammogram=previous_mammogram,
            breast_density=breast_density,
            hormone_therapy=hormone_therapy,
        )

    def _process_image(self, image: Image.Image) -> Image.Image:
        """Convert the image to RGB (if needed) and resize it to 224x224."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image.resize((224, 224))

    @torch.no_grad()
    def _analyze_image(self, image: Image.Image) -> AnalysisResult:
        """Run abnormality detection and, when an abnormality is found, size classification.

        Args:
            image: RGB image, as produced by ``_process_image``.

        Returns:
            An ``AnalysisResult``. ``tumor_size`` is ``"no-tumor"`` when no
            abnormality is detected; otherwise it is one of the numeric size labels.
        """
        metadata = self._generate_synthetic_metadata()

        # Detect abnormality.
        tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
        tumor_probs = self.tumor_detector(**tumor_inputs).logits.softmax(dim=-1)[0].cpu()
        # NOTE(review): class index 1 is assumed to mean "tumor" — confirm against id2label.
        # bool() so AnalysisResult.has_tumor is a real bool, not a 0-d tensor.
        has_tumor = bool(tumor_probs[1] > tumor_probs[0])

        # Measure size only if an abnormality was detected (skips a model call otherwise).
        tumor_size = self.SIZE_LABELS[0]
        if has_tumor:
            size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
            size_probs = self.size_detector(**size_inputs).logits.softmax(dim=-1)[0].cpu()
            # The detector has already established an abnormality, so choose the most
            # likely *actual* size class; otherwise the report could say "no-tumor cm".
            size_only = size_probs[1:len(self.SIZE_LABELS)]
            tumor_size = self.SIZE_LABELS[int(size_only.argmax().item()) + 1]

        return AnalysisResult(has_tumor, tumor_size, metadata)

    @staticmethod
    def _format_risk_factors(metadata: PatientMetadata) -> str:
        """Return the comma-separated risk-factor summary shown in prompt and results."""
        factors = [
            "family history of breast cancer" if metadata.family_history else "",
            metadata.smoking_status.lower(),
            "currently on hormone therapy" if metadata.hormone_therapy else "",
        ]
        # Drop the empty placeholders so no stray separators appear.
        return ", ".join(factor for factor in factors if factor)

    def _generate_medical_report(self, analysis: AnalysisResult) -> str:
        """Generate the findings/recommendations text with the language model.

        Falls back to ``_generate_fallback_report`` when generation raises or the
        model returns fewer than 10 words.
        """
        finding = "Abnormal area detected" if analysis.has_tumor else "No abnormalities detected"
        size_line = (
            f"- Size of abnormal area: {analysis.tumor_size} cm" if analysis.has_tumor else ""
        )
        prompt = f"""Generate a clear medical report for this breast imaging scan:

Scan Results:
- Finding: {finding}
{size_line}

Patient Information:
- Age: {analysis.metadata.age} years
- Risk factors: {self._format_risk_factors(analysis.metadata)}

Please provide:
1. A clear interpretation of the findings
2. A specific recommendation for next steps"""

        messages = [
            {
                "role": "system",
                "content": "You are a radiologist providing clear and straightforward medical reports. Focus on clarity and actionable recommendations.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ]

        try:
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=128,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
            )
            # Keep only the newly generated tokens, not the echoed prompt.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        except Exception as e:
            # Best effort: any generation failure (e.g. OOM) degrades to the template report.
            print(f"Error in report generation: {str(e)}")
            return self._generate_fallback_report(analysis)

        if len(response.split()) >= 10:
            return f"""FINDINGS AND RECOMMENDATIONS:
{response}"""
        return self._generate_fallback_report(analysis)

    def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
        """Return a template findings/recommendations report that needs no language model."""
        if not analysis.has_tumor:
            return """FINDINGS AND RECOMMENDATIONS:
Finding: No abnormal areas were detected during this scan.
Recommendation: Continue with routine screening as per standard guidelines."""

        if analysis.tumor_size in self.URGENT_SIZES:
            recommendation = "An immediate follow-up with conventional mammogram and ultrasound is required."
        else:
            recommendation = "A follow-up scan is recommended in 6 months."
        return f"""FINDINGS AND RECOMMENDATIONS:
Finding: An abnormal area measuring {analysis.tumor_size} cm was detected during the scan.
Recommendation: {recommendation}"""

    def analyze(self, image: Optional[Image.Image]) -> str:
        """Run the full pipeline on one image and return the text shown in the UI.

        Errors are returned as text (prefixed ``"Error during analysis:"``)
        rather than raised, so the UI always has something to display.
        """
        if image is None:
            # Gradio passes None when the form is submitted without an image.
            return "Error during analysis: no image was provided. Please upload an image."

        try:
            processed_image = self._process_image(image)
            analysis = self._analyze_image(processed_image)
            report = self._generate_medical_report(analysis)
        except Exception as e:
            return f"Error during analysis: {str(e)}"

        status = "⚠️ Abnormal area detected" if analysis.has_tumor else "✓ No abnormalities detected"
        size_line = f"Size of abnormal area: {analysis.tumor_size} cm" if analysis.has_tumor else ""
        return f"""SCAN RESULTS:
{status}
{size_line}

PATIENT INFORMATION:
• Age: {analysis.metadata.age} years
• Risk Factors: {self._format_risk_factors(analysis.metadata)}

{report}"""


def create_interface() -> gr.Interface:
    """Build the Gradio interface (loads all models via ``BreastSinogramAnalyzer``)."""
    analyzer = BreastSinogramAnalyzer()

    interface = gr.Interface(
        fn=analyzer.analyze,
        inputs=[
            gr.Image(type="pil", label="Upload Breast Image for Analysis"),
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20),
        ],
        title="Breast Imaging Analysis System",
        description="Upload a breast image for analysis and medical assessment.",
    )

    return interface


if __name__ == "__main__":
    print("Starting application...")
    interface = create_interface()
    # NOTE(review): share=True requests a public Gradio link — reconsider for medical images.
    interface.launch(debug=True, share=True)