# BreastCare / app2.py
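# Gradio demo: runs two ViT classifiers (tumor presence and tumor radius) on an
# uploaded sinogram and asks nvidia/Hymba-1.5B-Base to draft a short report.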
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
class SinogramAnalysisSystem:
    def __init__(self):
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load analysis models
        print("Loading tumor detection models...")
        self.tumor_classifier = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        ).to(self.device)
        self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")

        self.size_classifier = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        ).to(self.device)
        self.size_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

        # Load Hymba model
        print("Loading Hymba model...")
        repo_name = "nvidia/Hymba-1.5B-Base"
        self.tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
        self.llm = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
        self.llm = self.llm.to(self.device).to(torch.bfloat16)

        print("System ready!")
    def process_sinogram(self, image):
        if isinstance(image, str):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))
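
    # Run both classifiers on the preprocessed image: a binary tumor/no-tumor
    # head and a radius head whose argmax indexes the `sizes` labels (in cm).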
    @torch.no_grad()
    def analyze_sinogram(self, processed_image):
        # Detect tumor
        inputs = self.tumor_processor(processed_image, return_tensors="pt").to(self.device)
        outputs = self.tumor_classifier(**inputs)
        tumor_present = outputs.logits.softmax(dim=-1)[0].cpu()
        has_tumor = tumor_present[1] > tumor_present[0]

        # Assess size
        size_inputs = self.size_processor(processed_image, return_tensors="pt").to(self.device)
        size_outputs = self.size_classifier(**size_inputs)
        size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
        sizes = ["no-tumor", "0.5", "1.0", "1.5"]
        tumor_size = sizes[size_pred.argmax().item()]

        return has_tumor, tumor_size
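
    # Turn the classifier outputs into a findings prompt and let Hymba write
    # the interpretation, recommendations, and follow-up sections.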
    def generate_report(self, tumor_present, tumor_size):
        size_text = f"{tumor_size} cm" if tumor_size != "no-tumor" else "not applicable (no tumor detected)"
        prompt = f"""As a medical professional, provide a brief analysis of these sinogram findings:

Findings:
- Tumor Detection: {'Positive' if tumor_present else 'Negative'}
- Tumor Size: {size_text}

Please provide:
1. Brief interpretation
2. Clinical recommendations
3. Follow-up plan"""

        # Generate response using Hymba
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.llm.generate(
            **inputs,
            max_length=512,
            do_sample=True,
            temperature=0.7,
            use_cache=True
        )
        # Decode only the newly generated tokens, skipping the prompt
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()
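
    # Full pipeline used as the Gradio callback: preprocess, classify,
    # generate a report, and return a single plain-text summary (or the error).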
    def analyze_image(self, image):
        try:
            # Process sinogram
            processed = self.process_sinogram(image)
            tumor_present, tumor_size = self.analyze_sinogram(processed)

            # Generate medical report
            report = self.generate_report(tumor_present, tumor_size)

            # Format results
            size_text = f"{tumor_size} cm" if tumor_size != "no-tumor" else "not applicable (no tumor detected)"
            return f"""
SINOGRAM ANALYSIS:
• Tumor Detection: {'Positive' if tumor_present else 'Negative'}
• Size Assessment: {size_text}

MEDICAL REPORT:
{report}
"""
        except Exception as e:
            return f"Error during analysis: {str(e)}"
def create_interface():
    system = SinogramAnalysisSystem()

    iface = gr.Interface(
        fn=system.analyze_image,
        inputs=[
            gr.Image(type="pil", label="Upload Sinogram")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=15)
        ],
        title="Sinogram Analysis System",
        description="Upload a sinogram for tumor detection and medical assessment."
    )
    return iface
if __name__ == "__main__":
interface = create_interface()
interface.launch(debug=True, share=True)
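
# Example of calling the pipeline without the web UI (a sketch only; the path
# "sample_sinogram.png" is a placeholder for a local test image):
#
#     system = SinogramAnalysisSystem()
#     print(system.analyze_image("sample_sinogram.png"))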