MakiAi committed on
Commit
ab5a524
·
verified ·
1 Parent(s): 9b7e7d0

Upload server_fastapi.py

Browse files
Files changed (1) hide show
  1. server_fastapi.py +263 -0
server_fastapi.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API server for TTS
3
+ """
4
+ import argparse
5
+ import os
6
+ import sys
7
+ from io import BytesIO
8
+ from typing import Dict, Optional, Union
9
+ from urllib.parse import unquote
10
+
11
+ import GPUtil
12
+ import psutil
13
+ import torch
14
+ import uvicorn
15
+ from fastapi import FastAPI, HTTPException, Query, Request, status
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.responses import FileResponse, Response
18
+ from scipy.io import wavfile
19
+
20
+ from common.constants import (
21
+ DEFAULT_ASSIST_TEXT_WEIGHT,
22
+ DEFAULT_LENGTH,
23
+ DEFAULT_LINE_SPLIT,
24
+ DEFAULT_NOISE,
25
+ DEFAULT_NOISEW,
26
+ DEFAULT_SDP_RATIO,
27
+ DEFAULT_SPLIT_INTERVAL,
28
+ DEFAULT_STYLE,
29
+ DEFAULT_STYLE_WEIGHT,
30
+ Languages,
31
+ )
32
+ from common.log import logger
33
+ from common.tts_model import Model, ModelHolder
34
+ from config import config
35
+
36
# Language configured for the server; used below as the default value of the
# `language` query parameter of GET /voice.
ln = config.server_config.language
37
+
38
+
39
def raise_validation_error(msg: str, param: str):
    """Log a validation failure and abort the request with HTTP 422.

    Args:
        msg: Reason for the failure; written to the log and returned to the
            client in the error detail.
        param: Name of the offending query parameter; reported in the
            detail's ``loc`` field as ``["query", param]``.

    Raises:
        HTTPException: Always, with status 422 (Unprocessable Entity).
    """
    logger.warning(f"Validation error: {msg}")
    # A one-element list of {type, msg, loc} entries is sent as the detail.
    error_entry = {"type": "invalid_params", "msg": msg, "loc": ["query", param]}
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=[error_entry],
    )
45
+
46
+
47
class AudioResponse(Response):
    """Response class whose media type is ``audio/wav``.

    Used as ``response_class`` for the endpoints that return WAV data.
    """

    media_type = "audio/wav"
49
+
50
+
51
def load_models(model_holder: ModelHolder):
    """Load every model listed by ``model_holder`` and store them on it.

    For each ``(model_name, model_paths)`` entry of
    ``model_holder.model_files_dict`` a ``Model`` is created from the first
    listed model file together with ``config.json`` and ``style_vectors.npy``
    located in ``<root_dir>/<model_name>/``; ``load_net_g()`` is then called
    on it.

    The new list is assembled locally and assigned to ``model_holder.models``
    in a single step after all models have loaded. Resetting the attribute to
    an empty list up front would make requests that index into
    ``model_holder.models`` while a reload is still running see an empty or
    partial list, and an exception in the middle of loading would leave a
    truncated list behind.

    Args:
        model_holder: Holder providing ``model_files_dict``, ``root_dir`` and
            ``device``; its ``models`` attribute is replaced.
    """
    loaded_models = []
    for model_name, model_paths in model_holder.model_files_dict.items():
        model_dir = os.path.join(model_holder.root_dir, model_name)
        model = Model(
            model_path=model_paths[0],
            config_path=os.path.join(model_dir, "config.json"),
            style_vec_path=os.path.join(model_dir, "style_vectors.npy"),
            device=model_holder.device,
        )
        model.load_net_g()
        loaded_models.append(model)
    # Publish only once everything loaded successfully.
    model_holder.models = loaded_models
64
+
65
+
66
if __name__ == "__main__":
    # ---- command-line options -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    parser.add_argument(
        "--dir", "-d", type=str, help="Model directory", default=config.assets_root
    )
    args = parser.parse_args()

    # --cpu forces CPU; otherwise CUDA is used when torch reports it available.
    device = "cuda" if not args.cpu and torch.cuda.is_available() else "cpu"

    # ---- models ----------------------------------------------------------
    model_dir = args.dir
    model_holder = ModelHolder(model_dir, device)
    if len(model_holder.model_names) == 0:
        # Nothing to serve: report and exit with a non-zero status.
        logger.error(f"Models not found in {model_dir}.")
        sys.exit(1)

    logger.info("Loading models...")
    load_models(model_holder)

    # ---- application -----------------------------------------------------
    # `limit` is used as max_length of the `text` query parameter of /voice.
    limit = config.server_config.limit
    app = FastAPI()

    # CORS middleware is installed only when origins are configured.
    allow_origins = config.server_config.origins
    if allow_origins:
        logger.warning(
            f"CORS allow_origins={config.server_config.origins}. If you don't want, modify config.yml"
        )
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    app.logger = logger
102
+
103
# NOTE: FastAPI publishes an endpoint's docstring as the operation description
# in the generated API docs, so the docstring below is kept verbatim.
@app.get("/voice", response_class=AudioResponse)
async def voice(
    request: Request,
    text: str = Query(..., min_length=1, max_length=limit, description=f"セリフ"),
    encoding: str = Query(None, description="textをURLデコードする(ex, `utf-8`)"),
    model_id: int = Query(0, description="モデルID。`GET /models/info`のkeyの値を指定ください"),
    speaker_name: str = Query(
        None, description="話者名(speaker_idより優先)。esd.listの2列目の文字列を指定"
    ),
    speaker_id: int = Query(
        0, description="話者ID。model_assets>[model]>config.json内のspk2idを確認"
    ),
    sdp_ratio: float = Query(
        DEFAULT_SDP_RATIO,
        description="SDP(Stochastic Duration Predictor)/DP混合比。比率が高くなるほどトーンのばらつきが大きくなる",
    ),
    noise: float = Query(DEFAULT_NOISE, description="サンプルノイズの割合。大きくするほどランダム性が高まる"),
    noisew: float = Query(
        DEFAULT_NOISEW, description="SDPノイズ。大きくするほど発音の間隔にばらつきが出やすくなる"
    ),
    length: float = Query(
        DEFAULT_LENGTH, description="話速。基準は1で大きくするほど音声は長くなり読み上げが遅まる"
    ),
    language: Languages = Query(ln, description=f"textの言語"),
    auto_split: bool = Query(DEFAULT_LINE_SPLIT, description="改行で分けて生成"),
    split_interval: float = Query(
        DEFAULT_SPLIT_INTERVAL, description="分けた場合に挟む無音の長さ(秒)"
    ),
    assist_text: Optional[str] = Query(
        None, description="このテキストの読み上げと似た声音・感情になりやすくなる。ただし抑揚やテンポ等が犠牲になる傾向がある"
    ),
    assist_text_weight: float = Query(
        DEFAULT_ASSIST_TEXT_WEIGHT, description="assist_textの強さ"
    ),
    style: Optional[Union[int, str]] = Query(DEFAULT_STYLE, description="スタイル"),
    style_weight: float = Query(DEFAULT_STYLE_WEIGHT, description="スタイルの強さ"),
    reference_audio_path: Optional[str] = Query(None, description="スタイルを音声ファイルで行う"),
):
    """Infer text to speech(テキストから感情付き音声を生成する)"""
    # Flow: log the request, validate model / speaker / style against the
    # selected model, optionally URL-decode `text`, run inference, and return
    # the audio serialized as WAV.
    #
    # Validation failures are reported through raise_validation_error(),
    # i.e. as HTTP 422 responses naming the offending query parameter.
    logger.info(
        f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}"
    )

    # The upper bound cannot be expressed with Query(le=...) because
    # POST /models/refresh changes the number of loaded models at runtime.
    # Negative ids are rejected as well: they would otherwise either select a
    # model counted from the end of the list or raise IndexError.
    if not 0 <= model_id < len(model_holder.models):
        raise_validation_error(f"model_id={model_id} not found", "model_id")

    model = model_holder.models[model_id]

    # speaker_name takes precedence over speaker_id when it is given.
    if speaker_name is None:
        if speaker_id not in model.id2spk:
            raise_validation_error(
                f"speaker_id={speaker_id} not found", "speaker_id"
            )
    else:
        if speaker_name not in model.spk2id:
            raise_validation_error(
                f"speaker_name={speaker_name} not found", "speaker_name"
            )
        speaker_id = model.spk2id[speaker_name]

    if style not in model.style2id:
        raise_validation_error(f"style={style} not found", "style")

    if encoding is not None:
        try:
            text = unquote(text, encoding=encoding)
        except LookupError:
            # Unknown codec name: report it as a client error instead of
            # letting the request fail with an unhandled exception.
            raise_validation_error(
                f"encoding={encoding} is not supported", "encoding"
            )

    # SECURITY NOTE(review): reference_audio_path is supplied by the client
    # and handed to model.infer unchanged; presumably it is opened as a local
    # file on the server — confirm, and restrict it if the API is exposed to
    # untrusted clients.
    sr, audio = model.infer(
        text=text,
        language=language,
        sid=speaker_id,
        reference_audio_path=reference_audio_path,
        sdp_ratio=sdp_ratio,
        noise=noise,
        noisew=noisew,
        length=length,
        line_split=auto_split,
        split_interval=split_interval,
        assist_text=assist_text,
        assist_text_weight=assist_text_weight,
        # The assist text is only used when a non-empty one was given.
        use_assist_text=bool(assist_text),
        style=style,
        style_weight=style_weight,
    )
    logger.success("Audio data generated and sent successfully")

    # Serialize into an in-memory buffer; the bytes are copied out before the
    # buffer is closed.
    with BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, sr, audio)
        # AudioResponse already carries media_type "audio/wav" and matches the
        # response_class declared on the route.
        return AudioResponse(content=wav_buffer.getvalue())
185
+
186
# NOTE: the docstring is kept verbatim because FastAPI publishes it as the
# operation description. English: "Get information about the loaded models."
@app.get("/models/info")
def get_loaded_models_info():
    """ロードされたモデル情報の取得"""
    # One entry per loaded model, keyed by its index in model_holder.models
    # rendered as a string (the value clients pass as `model_id` to /voice).
    return {
        str(model_id): {
            "config_path": model.config_path,
            "model_path": model.model_path,
            "device": model.device,
            "spk2id": model.spk2id,
            "id2spk": model.id2spk,
            "style2id": model.style2id,
        }
        for model_id, model in enumerate(model_holder.models)
    }
201
+
202
# NOTE: the docstring is kept verbatim because FastAPI publishes it as the
# operation description. English: "Make the server load models again, e.g.
# after models were added to / removed from the model path."
@app.post("/models/refresh")
def refresh():
    """モデルをパスに追加/削除した際などに読み込ませる"""
    # Refresh the holder first, then rebuild the loaded models from it.
    model_holder.refresh()
    load_models(model_holder)
    # Respond with the same payload as GET /models/info.
    refreshed_info = get_loaded_models_info()
    return refreshed_info
208
+
209
# NOTE: the docstring is kept verbatim because FastAPI publishes it as the
# operation description. English: "Get the status of the execution
# environment."
@app.get("/status")
def get_status():
    """実行環境のステータスを取得"""
    # CPU usage is requested from psutil with interval=1.
    cpu_percent = psutil.cpu_percent(interval=1)
    memory_info = psutil.virtual_memory()

    # "cpu" is always listed, followed by one "cuda:<i>" entry per CUDA
    # device counted by torch.
    devices = ["cpu", *(f"cuda:{i}" for i in range(torch.cuda.device_count()))]

    # Per-GPU load and memory figures as reported by GPUtil.
    gpu_info = [
        {
            "gpu_id": gpu.id,
            "gpu_load": gpu.load,
            "gpu_memory": {
                "total": gpu.memoryTotal,
                "used": gpu.memoryUsed,
                "free": gpu.memoryFree,
            },
        }
        for gpu in GPUtil.getGPUs()
    ]

    return {
        "devices": devices,
        "cpu_percent": cpu_percent,
        "memory_total": memory_info.total,
        "memory_available": memory_info.available,
        "memory_used": memory_info.used,
        "memory_percent": memory_info.percent,
        "gpu": gpu_info,
    }
244
+
245
# NOTE: the docstring is kept verbatim because FastAPI publishes it as the
# operation description. English: "Get wav data."
@app.get("/tools/get_audio", response_class=AudioResponse)
def get_audio(
    request: Request, path: str = Query(..., description="local wav path")
):
    """wavデータを取得する"""
    # Serves the file at the client-supplied local `path` as audio/wav.
    # Invalid paths are rejected through raise_validation_error() (HTTP 422).
    #
    # SECURITY NOTE(review): `path` comes straight from the query string and
    # is served as-is, so any readable .wav file on the host can be
    # downloaded. Restricting it to a known directory would be safer but would
    # change which paths existing clients may request, so it is only flagged
    # here.
    logger.info(
        f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
    )
    # Check the extension before touching the filesystem: with the existence
    # check first, the differing error messages let a client probe whether an
    # arbitrary (non-wav) file exists on the server.
    if not path.lower().endswith(".wav"):
        raise_validation_error(f"wav file not found in {path}", "path")
    if not os.path.isfile(path):
        raise_validation_error(f"path={path} not found", "path")
    return FileResponse(path=path, media_type="audio/wav")
258
+
259
# Announce where the server and its interactive API docs can be reached.
# The logged URLs use 127.0.0.1, while uvicorn is given host "0.0.0.0".
logger.info(f"server listen: http://127.0.0.1:{config.server_config.port}")
logger.info(f"API docs: http://127.0.0.1:{config.server_config.port}/docs")

# Start serving the app; uvicorn's own log level is set to "warning".
uvicorn.run(
    app,
    host="0.0.0.0",
    port=config.server_config.port,
    log_level="warning",
)