import gradio as gr
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    pipeline
)
from PIL import Image
import torch
import random
import json


class MedicalImageAnalysisSystem:
    def __init__(self):
        print("Initializing system...")

        # Check for CUDA availability
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load models one at a time with progress messages
        print("Loading tumor classifier...")
        self.tumor_classifier_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.tumor_classifier_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        )

        print("Loading tumor radius detector...")
        self.radius_model = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.radius_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

        print("Loading language model...")
        # Using a smaller model for faster inference
        self.llm = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
            model_kwargs={"low_cpu_mem_usage": True}
        )

        # Set models to evaluation mode
        self.tumor_classifier_model.eval()
        self.radius_model.eval()

        print("System initialization complete!")

    def generate_synthetic_metadata(self):
        return {
            "age": random.randint(25, 85),
            "gender": random.choice(["Male", "Female"]),
            "smoking_status": random.choice(["Never Smoker", "Former Smoker", "Current Smoker"]),
            "drinking_status": random.choice(["Non-drinker", "Social Drinker", "Regular Drinker"]),
            "medications": random.sample([
                "Lisinopril", "Metformin", "Levothyroxine", "Amlodipine",
                "Metoprolol", "Omeprazole", "Simvastatin", "Losartan"
            ], random.randint(0, 3))
        }

    def process_image(self, image):
        if isinstance(image, str):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    @torch.no_grad()
    def predict_tumor_presence(self, processed_image):
        inputs = self.tumor_classifier_processor(processed_image, return_tensors="pt")
        # Cast inputs to the model's dtype (fp16 on CUDA, fp32 on CPU) to avoid dtype mismatches
        inputs = {k: v.to(self.device, dtype=self.tumor_classifier_model.dtype)
                  for k, v in inputs.items()}
        outputs = self.tumor_classifier_model(**inputs)
        predictions = torch.softmax(outputs.logits, dim=-1)
        probs = predictions[0].cpu().tolist()
        # Return just the predicted class instead of probabilities
        return "tumor" if probs[1] > probs[0] else "non-tumor"

    @torch.no_grad()
    def predict_tumor_radius(self, processed_image):
        inputs = self.radius_processor(processed_image, return_tensors="pt")
        # Cast inputs to the model's dtype (fp16 on CUDA, fp32 on CPU) to avoid dtype mismatches
        inputs = {k: v.to(self.device, dtype=self.radius_model.dtype)
                  for k, v in inputs.items()}
        outputs = self.radius_model(**inputs)
        predictions = outputs.logits.softmax(dim=-1)
        predicted_label = predictions.argmax().item()
        class_names = ["no-tumor", "0.5", "1.0", "1.5"]
        # Return just the radius without confidence
        return class_names[predicted_label]

    def generate_llm_interpretation(self, tumor_presence, tumor_radius, metadata):
        prompt = f"""<|system|>You are a medical AI assistant. Provide a clear and concise medical interpretation.
<|user|>Analyze the following medical findings:

Image Analysis:
- Tumor Detection: {tumor_presence}
- Tumor Size: {tumor_radius} cm

Patient Profile:
- Age: {metadata['age']} years
- Gender: {metadata['gender']}
- Smoking: {metadata['smoking_status']}
- Alcohol: {metadata['drinking_status']}
- Current Medications: {', '.join(metadata['medications']) if metadata['medications'] else 'None'}

Provide a brief:
1. Key findings
2. Clinical recommendations
3. Follow-up plan
<|assistant|>"""

        response = self.llm(
            prompt,
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1
        )

        return response[0]['generated_text'].split("<|assistant|>")[-1].strip()

    def analyze_image(self, image):
        try:
            # Process image and generate metadata
            processed_image = self.process_image(image)
            metadata = self.generate_synthetic_metadata()

            # Get predictions
            tumor_presence = self.predict_tumor_presence(processed_image)
            tumor_radius = self.predict_tumor_radius(processed_image)

            # Generate interpretation
            interpretation = self.generate_llm_interpretation(
                tumor_presence, tumor_radius, metadata
            )

            # Format results
            results = {
                "metadata": metadata,
                "tumor_presence": tumor_presence,
                "tumor_radius": tumor_radius,
                "interpretation": interpretation
            }

            return self.format_results(results)

        except Exception as e:
            return f"Error: {str(e)}"

    def format_results(self, results):
        return f"""
Patient Information:
{json.dumps(results['metadata'], indent=2)}

Image Analysis Results:
- Tumor Detection: {results['tumor_presence']}
- Tumor Size: {results['tumor_radius']} cm

Medical Assessment:
{results['interpretation']}
"""


def create_interface():
    system = MedicalImageAnalysisSystem()

    iface = gr.Interface(
        fn=system.analyze_image,
        inputs=[
            gr.Image(type="pil", label="Upload Medical Image")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Medical Image Analysis System",
        description="Upload a medical image for tumor analysis and recommendations.",
        theme=gr.themes.Base()
    )

    return iface


if __name__ == "__main__":
    print("Starting application...")
    iface = create_interface()
    iface.queue()
    iface.launch(debug=True, share=True)
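
# Optional usage sketch (not part of the app above): `analyze_image` also accepts
# a file path, since `process_image` opens strings with PIL, so the analysis can
# be run without the Gradio UI. The path "sample_scan.png" is a hypothetical
# placeholder for a local test image.
#
#   system = MedicalImageAnalysisSystem()
#   print(system.analyze_image("sample_scan.png"))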