Gregniuki commited on
Commit
fc66451
·
1 Parent(s): 81a4e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -6
app.py CHANGED
@@ -13,12 +13,12 @@ enhanced_accessibility = False #@param {type:"boolean"}
13
  #@markdown ---
14
  use_gpu = False #@param {type:"boolean"}
15
 
16
- from fastapi import FastAPI, Request, Form
17
  from fastapi.responses import HTMLResponse
18
  from fastapi.responses import FileResponse
19
  from fastapi.templating import Jinja2Templates
20
  from fastapi.staticfiles import StaticFiles
21
-
22
  # ...
23
  # Mount a directory to serve static files (e.g., CSS and JavaScript)
24
 
@@ -125,7 +125,21 @@ def detect_onnx_models(path):
125
  return onnx_models[0], onnx_configs[0]
126
  else:
127
  return None
128
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @app.get("/get_speaker_id_map")
130
  async def get_speaker_id_map(selected_model: str):
131
  config = model_configurations.get(selected_model)
@@ -177,6 +191,21 @@ def load_model_configuration(models_path, config_name):
177
  return None
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
 
@@ -185,12 +214,13 @@ async def main(
185
  request: Request,
186
  text_input: str = Form(default="1, 2, 3. This is a test. Enter some text to generate."),
187
  selected_model: str = Form(...), # Selected model
188
- speaker_select: str = Form(...), # Selected speaker ID
189
  speaker: str = Form(...),
190
  speed_slider: float = Form(...),
191
  noise_scale_slider: float = Form(...),
192
  noise_scale_w_slider: float = Form(...),
193
- play: bool = Form(True)
 
194
  ):
195
  # ... (previous code)
196
 
@@ -240,8 +270,11 @@ async def main(
240
  "file_url": file_url,
241
  "text_input": text_input,
242
  "data": data,
243
- "model_names": onnx_models,
244
  "selected_model": selected_model,
 
 
 
245
  "speaker_id_map": speaker_id_map, # Make sure speaker_id_map is included here
246
  "speaker_select": speaker_select,
247
  "dynamic_content": response_html
 
13
  #@markdown ---
14
  use_gpu = False #@param {type:"boolean"}
15
 
16
+ from fastapi import FastAPI, Request, Form, Depends
17
  from fastapi.responses import HTMLResponse
18
  from fastapi.responses import FileResponse
19
  from fastapi.templating import Jinja2Templates
20
  from fastapi.staticfiles import StaticFiles
21
+ from typing import Tuple
22
  # ...
23
  # Mount a directory to serve static files (e.g., CSS and JavaScript)
24
 
 
125
  return onnx_models[0], onnx_configs[0]
126
  else:
127
  return None
128
+ # Define a dependency function to get the selected_model and selected_speaker_id on startup
129
+ def get_initial_values():
130
+ # You can set default values or load them from a configuration file here
131
+ selected_model = onnx_models[0] if onnx_models else "default_model"
132
+ selected_speaker_id = "default_speaker_id" # Default value
133
+
134
+ # Check if there are onnx models and load the speaker_id_map from the first model's config
135
+ if onnx_models:
136
+ first_model_config = model_configurations.get(onnx_models[0])
137
+ if first_model_config:
138
+ speaker_id_map = first_model_config.get("speaker_id_map")
139
+ if speaker_id_map:
140
+ selected_speaker_id = next(iter(speaker_id_map)) # Get the first speaker_id
141
+
142
+ return selected_model, selected_speaker_id
143
  @app.get("/get_speaker_id_map")
144
  async def get_speaker_id_map(selected_model: str):
145
  config = model_configurations.get(selected_model)
 
191
  return None
192
 
193
 
194
+ # Define a dependency function to get the selected_model and selected_speaker_id on startup
195
+ def get_initial_values() -> Tuple[str, str]:
196
+ # You can set default values or load them from a configuration file here
197
+ selected_model = onnx_models[0] if onnx_models else "default_model"
198
+ selected_speaker_id = "default_speaker_id" # Default value
199
+
200
+ # Check if there are onnx models and load the speaker_id_map from the first model's config
201
+ if onnx_models:
202
+ first_model_config = model_configurations.get(onnx_models[0])
203
+ if first_model_config:
204
+ speaker_id_map = first_model_config.get("speaker_id_map")
205
+ if speaker_id_map:
206
+ selected_speaker_id = next(iter(speaker_id_map)) # Get the first speaker_id
207
+
208
+ return selected_model, selected_speaker_id
209
 
210
 
211
 
 
214
  request: Request,
215
  text_input: str = Form(default="1, 2, 3. This is a test. Enter some text to generate."),
216
  selected_model: str = Form(...), # Selected model
217
+ selected_speaker_id: str = Form(...), # Selected speaker ID
218
  speaker: str = Form(...),
219
  speed_slider: float = Form(...),
220
  noise_scale_slider: float = Form(...),
221
  noise_scale_w_slider: float = Form(...),
222
+ play: bool = Form(True),
223
+ initial_values: Tuple[str, str] = Depends(get_initial_values) # Use the dependency here
224
  ):
225
  # ... (previous code)
226
 
 
270
  "file_url": file_url,
271
  "text_input": text_input,
272
  "data": data,
273
+ # "model_names": onnx_models,
274
  "selected_model": selected_model,
275
+ "model_names": model_configurations.keys(),
276
+ "initial_selected_model": initial_selected_model,
277
+ "initial_selected_speaker_id": initial_selected_speaker_id,
278
  "speaker_id_map": speaker_id_map, # Make sure speaker_id_map is included here
279
  "speaker_select": speaker_select,
280
  "dynamic_content": response_html