Charles95 commited on
Commit
5c38c6f
·
verified ·
1 Parent(s): 482eadc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from PIL import Image
4
+ import io
5
+ import torch
6
+ from torchvision import models, transforms
7
+
8
+ # 加载预训练的ResNet-50模型
9
+ model = models.resnet50(pretrained=True)
10
+ model.eval() # 设置模型为评估模式
11
+
12
+ # 图像预处理
13
+ preprocess = transforms.Compose([
14
+ transforms.Resize(256),
15
+ transforms.CenterCrop(224),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
+ ])
19
+
20
+ # 创建FastAPI应用实例
21
+ app = FastAPI()
22
+
23
+ @app.post("/predict/")
24
+ async def predict(file: UploadFile = File(...)):
25
+ contents = await file.read()
26
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
27
+
28
+ # 预处理图片
29
+ input_tensor = preprocess(image)
30
+ input_batch = input_tensor.unsqueeze(0) # 添加批处理维度
31
+
32
+ with torch.no_grad():
33
+ output = model(input_batch)
34
+
35
+ # 获取预测结果
36
+ _, predicted_idx = torch.max(output, 1)
37
+
38
+ # 可以在此处添加代码来获取类别名称,这里只返回索引
39
+ return JSONResponse(content={"predicted_class": int(predicted_idx[0])})
40
+
41
+ # 运行服务
42
+ if __name__ == "__main__":
43
+ import uvicorn
44
+ uvicorn.run(app, host="0.0.0.0", port=8000)