"""Gradio demo that classifies an uploaded image with a trained SimpleCNN."""

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import SimpleCNN

# Built once at import time: the pipeline is stateless, so there is no need
# to reconstruct it for every incoming image.
# NOTE(review): the mean/std values are the commonly used ImageNet statistics;
# they must match whatever normalization was used at training time -- confirm.
_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image) -> torch.Tensor:
    """Convert a raw image array into a normalized, batched model input.

    Args:
        image: Array-like pixel data accepted by ``PIL.Image.fromarray``
            (presumably the NumPy array Gradio's ``Image`` component
            delivers -- verify against the component configuration).

    Returns:
        The image resized to 32x32, converted to a tensor, normalized, and
        given a leading batch dimension of size 1.
    """
    pil_image = Image.fromarray(image)

    # Ensure that the image has three channels (RGB); Normalize is configured
    # with three-channel statistics.
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    tensor = _TRANSFORM(pil_image)
    return tensor.unsqueeze(0)


def predict_image(model, image) -> int:
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Callable model that maps the input batch to per-class scores.
        image: Batched input tensor holding exactly one image.

    Returns:
        The predicted class index as a plain Python ``int``.
    """
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image)

    # argmax over the class dimension; .item() requires a batch of size one.
    return output.argmax(dim=1).item()


def main():
    """Load the trained weights and launch the Gradio classification UI."""
    model = SimpleCNN()
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only host; the model and its inputs stay on the CPU here.
    model.load_state_dict(torch.load('cifar10_model.pth', map_location='cpu'))

    # Set the model to evaluation mode
    model.eval()

    def classify(img):
        """Gradio callback: return the predicted class index, or None."""
        # Guard against an empty input (e.g. the image was cleared while the
        # interface is live); Image.fromarray cannot handle None.
        if img is None:
            return None
        return predict_image(model, preprocess_image(img))

    iface = gr.Interface(
        fn=classify,
        inputs=gr.Image(),
        outputs="label",
        live=True,
    )

    iface.launch()


if __name__ == "__main__":
    main()