"""Gradio demo that classifies uploaded images with a trained CIFAR-10 CNN."""

import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
from model import SimpleCNN

# Class names in the order used by torchvision's CIFAR-10 dataset.
# NOTE(review): assumes the model was trained with this label order and has
# exactly 10 outputs — confirm against the training script.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)

# Built once at import time instead of on every request.
# NOTE(review): these are the ImageNet mean/std, not CIFAR-10 statistics;
# they must match the normalization used during training — confirm.
_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an image array into a normalized single-image model batch.

    Args:
        image: Image as an array accepted by ``PIL.Image.fromarray`` (e.g. the
            NumPy array delivered by ``gr.Image()``). Grayscale, RGB and RGBA
            inputs are all accepted.

    Returns:
        A float tensor of shape (1, 3, 32, 32).
    """
    # Force 3 channels: Normalize uses 3-element mean/std, so 1-channel
    # (grayscale) or 4-channel (RGBA) images would otherwise raise.
    pil_image = Image.fromarray(image).convert("RGB")
    return _TRANSFORM(pil_image).unsqueeze(0)


def predict_image(model, image):
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Network assumed to return scores of shape (batch, num_classes).
        image: Preprocessed batch containing exactly one image, e.g. the
            output of ``preprocess_image``.

    Returns:
        The predicted class index as a Python int.
    """
    model.eval()
    with torch.no_grad():
        output = model(image)
    # argmax over the class dimension; .item() requires a batch of one.
    return output.argmax(dim=1).item()


def predict_label(model, image):
    """Return per-class probabilities for a single preprocessed image.

    Args:
        model: Network assumed to return scores of shape (1, num_classes).
        image: Preprocessed batch containing exactly one image.

    Returns:
        Dict mapping each CIFAR-10 class name to its softmax probability —
        the format ``gr.Label`` renders as a ranked list.
    """
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(image), dim=1)[0].tolist()
    return dict(zip(CIFAR10_CLASSES, probs))


def main():
    """Load the trained weights and serve the Gradio classification demo."""
    model = SimpleCNN()
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load("cifar10_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    def classify(image):
        # The input can be None (e.g. when the user clears the image while
        # live=True); clear the output instead of crashing.
        if image is None:
            return None
        return predict_label(model, preprocess_image(image))

    iface = gr.Interface(
        fn=classify,
        inputs=gr.Image(),
        outputs=gr.Label(num_top_classes=3),
        live=True,
    )
    # share=True requests a public share link; drop it to serve locally only.
    iface.launch(share=True)


if __name__ == "__main__":
    main()