# Spaces:
# No application file
# No application file
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import models, transforms
# Load ResNet-50 with its ImageNet-pretrained weights.
# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the ``weights=``
# enum. ``IMAGENET1K_V1`` is the checkpoint that ``pretrained=True`` loaded and the
# one the 256/224 preprocessing below was designed for (``DEFAULT`` would select
# the V2 checkpoint, which expects different resizing). Older torchvision
# releases have no ``ResNet50_Weights``, so fall back to the legacy flag there.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()  # evaluation mode: BatchNorm layers use their running statistics

# Preprocessing applied to every incoming image:
# resize the shorter side to 256 px, center-crop to 224x224, convert to a float
# tensor in [0, 1] with shape (C, H, W), then normalize each channel with the
# ImageNet mean/std used by torchvision's pretrained classification models.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# FastAPI application instance.
app = FastAPI()
def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw uploaded bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's
            pixel-count safety limit.
    """
    # ``convert`` loads the pixel data and returns a converted copy, so the
    # source image can be closed as soon as the conversion is done.
    with Image.open(io.BytesIO(contents)) as img:
        return img.convert("RGB")


def _classify(image: Image.Image) -> int:
    """Run the module-level ``model`` on one RGB image; return the top-1 class index."""
    input_batch = preprocess(image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model(input_batch)
    # Index of the highest-scoring class for the single image in the batch.
    return int(output.argmax(dim=1)[0])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with ResNet-50.

    Args:
        file: the uploaded image file (multipart form field ``file``).

    Returns:
        JSONResponse ``{"predicted_class": <int>}`` holding the top-1 class
        index. Class names are not resolved here; only the index is returned.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    # Decoding and inference are CPU-bound; run them in the threadpool so a
    # slow request does not block the event loop for every other client.
    try:
        image = await run_in_threadpool(_decode_image, contents)
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    predicted_idx = await run_in_threadpool(_classify, image)
    return JSONResponse(content={"predicted_class": predicted_idx})
# Entry point: serve the app with uvicorn when this file is executed directly.
if __name__ == "__main__":
    import uvicorn as _uvicorn

    # Bind on every network interface (not only localhost), port 8000.
    _host, _port = "0.0.0.0", 8000
    _uvicorn.run(app, host=_host, port=_port)