import gradio as gr
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    pipeline
)
from PIL import Image
import torch
import random
import json


class MedicalImageAnalysisSystem:
    def __init__(self):
        print("Initializing system...")

        # Check for CUDA availability
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load models one at a time with progress messages
        print("Loading tumor classifier...")
        self.tumor_classifier_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.tumor_classifier_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        )

        print("Loading tumor radius detector...")
        self.radius_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.radius_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

        print("Loading language model...")
        # Using a smaller model for faster inference
        self.llm = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
            model_kwargs={"low_cpu_mem_usage": True}
        )

        print("System initialization complete!")

    def generate_synthetic_metadata(self):
        # Fabricated demo metadata; not derived from the image
        return {
            "age": random.randint(25, 85),
            "gender": random.choice(["Male", "Female"]),
            "smoking_status": random.choice(["Never Smoker", "Former Smoker", "Current Smoker"]),
            "drinking_status": random.choice(["Non-drinker", "Social Drinker", "Regular Drinker"]),
            "medications": random.sample([
                "Lisinopril", "Metformin", "Levothyroxine", "Amlodipine",
                "Metoprolol", "Omeprazole", "Simvastatin", "Losartan"
            ], random.randint(0, 3))
        }

    def process_image(self, image):
        if isinstance(image, str):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    @torch.no_grad()  # Disable gradient computation for inference
    def predict_tumor_presence(self, processed_image):
        inputs = self.tumor_classifier_processor(processed_image, return_tensors="pt")
        # Move inputs to the model's device AND dtype; the processor emits
        # float32 pixel values, which would otherwise clash with a fp16 model
        inputs = {
            k: v.to(self.device, self.tumor_classifier_model.dtype)
            for k, v in inputs.items()
        }
        outputs = self.tumor_classifier_model(**inputs)
        predictions = torch.softmax(outputs.logits, dim=-1)
        probs = predictions[0].cpu().tolist()  # Move back to CPU for plain-Python access
        return {
            "non-tumor": float(probs[0]),
            "tumor": float(probs[1])
        }

    @torch.no_grad()  # Disable gradient computation for inference
    def predict_tumor_radius(self, processed_image):
        inputs = self.radius_processor(processed_image, return_tensors="pt")
        # Match the model's device and dtype, as above
        inputs = {
            k: v.to(self.device, self.radius_model.dtype)
            for k, v in inputs.items()
        }
        outputs = self.radius_model(**inputs)
        predictions = outputs.logits.softmax(dim=-1)
        predicted_label = predictions.argmax().item()
        confidence = predictions[0][predicted_label].cpu().item()  # Move back to CPU
        class_names = ["no-tumor", "0.5", "1.0", "1.5"]
        return {
            "radius": class_names[predicted_label],
            "confidence": float(confidence)
        }

    def generate_llm_interpretation(self, tumor_presence, tumor_radius, metadata):
        # TinyLlama-Chat uses the Zephyr-style template:
        # <|system|> ... </s> <|user|> ... </s> <|assistant|>
        prompt = f"""<|system|>
You are a medical AI assistant. Be concise but thorough.</s>
<|user|>
Analyze these results:
Tumor Detection: {json.dumps(tumor_presence)}
Tumor Radius: {json.dumps(tumor_radius)}
Patient: {metadata['age']}y/o {metadata['gender']}, {metadata['smoking_status']}, {metadata['drinking_status']}
Medications: {', '.join(metadata['medications']) if metadata['medications'] else 'None'}

Provide:
1. Key findings
2. Risk assessment
3. Recommendations</s>
<|assistant|>
"""

        response = self.llm(
            prompt,
            max_new_tokens=300,  # Reduced for faster response
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1
        )

        # The pipeline returns prompt + completion; keep only the completion
        return response[0]['generated_text'].split("<|assistant|>")[-1].strip()

    def analyze_image(self, image):
        try:
            # Generator: yield progress updates, then the final report
            yield "Processing image..."
            processed_image = self.process_image(image)

            yield "Generating patient metadata..."
            metadata = self.generate_synthetic_metadata()

            yield "Analyzing tumor presence..."
            tumor_presence = self.predict_tumor_presence(processed_image)

            yield "Analyzing tumor radius..."
            tumor_radius = self.predict_tumor_radius(processed_image)

            yield "Generating medical interpretation..."
            interpretation = self.generate_llm_interpretation(
                tumor_presence, tumor_radius, metadata
            )

            # Final results
            result = {
                "metadata": metadata,
                "tumor_presence": tumor_presence,
                "tumor_radius": tumor_radius,
                "interpretation": interpretation
            }
            yield self.format_results(result)

        except Exception as e:
            yield f"Error: {str(e)}"

    def format_results(self, results):
        return f"""
Patient Metadata:
{json.dumps(results['metadata'], indent=2)}

Tumor Presence Analysis:
{json.dumps(results['tumor_presence'], indent=2)}

Tumor Radius Analysis:
{json.dumps(results['tumor_radius'], indent=2)}

Medical Interpretation and Recommendations:
{results['interpretation']}
"""


def create_interface():
    system = MedicalImageAnalysisSystem()

    iface = gr.Interface(
        fn=system.analyze_image,
        inputs=[
            gr.Image(type="pil", label="Upload Medical Image")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Medical Image Analysis System",
        description="Upload a medical image for tumor analysis and recommendations.",
        theme=gr.themes.Base(),
        allow_flagging="never"  # gr.Interface has no `flagging` argument
    )

    return iface


if __name__ == "__main__":
    print("Starting application...")
    iface = create_interface()
    iface.queue()  # Enable queuing so the generator can stream progress updates
    iface.launch(debug=True, share=True)
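

# --- Example: programmatic use without the Gradio UI ---
# A minimal sketch of driving the analyzer directly. "sample_scan.png" is a
# placeholder path, not a file shipped with this script. analyze_image is a
# generator, so we drain the progress strings and keep the final report.
def run_without_ui(image_path="sample_scan.png"):
    system = MedicalImageAnalysisSystem()
    report = None
    for update in system.analyze_image(image_path):
        report = update  # intermediate yields are progress messages
    return report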