import gradio as gr
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoProcessor, AutoModel
from peft import PeftModel
from PIL import Image


class ClassificationHead(nn.Module):
    """Binary classifier (not fumo / fumo) over pooled vision features."""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 2)

    def forward(self, x):
        return self.linear(x)


def load_model():
    device = torch.device("cpu")

    # Vision tower of SigLIP-so400m, loaded in float32 on CPU.
    base_model = AutoModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        device_map="cpu",
        torch_dtype=torch.float32,
        attn_implementation="sdpa",
    ).vision_model

    # Attach the locally stored LoRA adapter to the vision tower.
    model = PeftModel.from_pretrained(base_model, "fumo_lora", local_files_only=True)
    processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    # SigLIP-so400m has a hidden size of 1152, which the head consumes.
    head = ClassificationHead(1152)
    head.load_state_dict(
        torch.load(
            "fumo_lora/classification_head.pth",
            weights_only=True,
            map_location="cpu",
        )
    )

    model.eval()
    head.eval()
    return model, processor, head, device


model, processor, head, device = load_model()


def predict_image(image):
    if image is None:
        return "Please provide an image."
    try:
        # Resize/normalize the image to SigLIP's expected input format.
        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = model(
                pixel_values=inputs.pixel_values.to(device, dtype=torch.float32),
            )
            # Mean-pool the patch embeddings, then classify.
            pooled = outputs.last_hidden_state.mean(dim=1)
            logits = head(pooled)
            prob = F.softmax(logits, dim=1)

        fumo_prob = prob[0, 1].item()
        not_fumo_prob = prob[0, 0].item()

        result = "Results:\n"
        result += f"Fumo probability: {fumo_prob:.3f}\n"
        result += f"Not fumo probability: {not_fumo_prob:.3f}\n"
        result += f"\nVerdict: {'FUMO!' if fumo_prob > 0.5 else 'Not a fumo'}"
        return result
    except Exception as e:
        return f"Error: {e}"


# Extra HTML injected into the page <head>; currently empty.
htmlhead = """ """

# Create Gradio interface
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", width=384, height=384),
    outputs=gr.Textbox(),
    title="Fumo Classifier (LoRA)",
    description="Drop an image to check if it's a Fumo!",
    examples=[
        "examples/fumo1.jpg",
        "examples/fumo2.jpg",
        "examples/no_fumo1.jpg",
        "examples/no_fumo2.jpg",
        "examples/no_fumo3.png",
    ],
    flagging_mode="manual",
    flagging_options=["Correct 👍", "Incorrect 👎"],
    head=htmlhead,
)

if __name__ == "__main__":
    # Diagnostic: report whether this machine supports bfloat16 on GPU and CPU.
    has_bf16 = torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False
    has_bf16_cpu = (
        torch.cpu.is_bf16_supported() if hasattr(torch.cpu, "is_bf16_supported") else False
    )
    print(f"BF16 support: {has_bf16} (GPU), {has_bf16_cpu} (CPU)")
    demo.launch()
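
# Optional smoke test (a sketch, not part of the app): checks predict_image
# without starting the Gradio UI. It assumes this file is saved as app.py and
# that the bundled example image from the examples list above exists. From a
# Python REPL in this directory:
#
#   >>> from app import predict_image
#   >>> from PIL import Image
#   >>> print(predict_image(Image.open("examples/fumo1.jpg")))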