from fastapi import FastAPI, File, UploadFile from fastapi.responses import JSONResponse from PIL import Image import io import torch from torchvision import models, transforms # 加载预训练的ResNet-50模型 model = models.resnet50(pretrained=True) model.eval() # 设置模型为评估模式 # 图像预处理 preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # 创建FastAPI应用实例 app = FastAPI() @app.post("/predict/") async def predict(file: UploadFile = File(...)): contents = await file.read() image = Image.open(io.BytesIO(contents)).convert("RGB") # 预处理图片 input_tensor = preprocess(image) input_batch = input_tensor.unsqueeze(0) # 添加批处理维度 with torch.no_grad(): output = model(input_batch) # 获取预测结果 _, predicted_idx = torch.max(output, 1) # 可以在此处添加代码来获取类别名称,这里只返回索引 return JSONResponse(content={"predicted_class": int(predicted_idx[0])}) # 运行服务 if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)