Yurii Paniv commited on
Commit
b75a2aa
·
1 Parent(s): f9e5028

Fix memory leak, remove gradient storage

Browse files
Files changed (2) hide show
  1. .gitignore +4 -1
  2. app.py +13 -10
.gitignore CHANGED
@@ -130,4 +130,7 @@ dmypy.json
130
 
131
  # model files
132
  *.pth.tar
133
- *.pth
 
 
 
 
130
 
131
  # model files
132
  *.pth.tar
133
+ *.pth
134
+
135
+ # gradio
136
+ gradio_queue.db
app.py CHANGED
@@ -10,6 +10,8 @@ from formatter import preprocess_text
10
  from datetime import datetime
11
  from stress import sentence_to_stress
12
  from enum import Enum
 
 
13
 
14
  class StressOption(Enum):
15
  ManualStress = "Наголоси вручну"
@@ -50,21 +52,22 @@ for MODEL_NAME in MODEL_NAMES:
50
 
51
 
52
  def tts(text: str, stress: str):
53
- synthesizer = Synthesizer(
54
- model_path, config_path, None, None, None,
55
- )
56
  text = preprocess_text(text)
57
- text_limit = 150
58
  text = text if len(text) < text_limit else text[0:text_limit] # mitigate crashes on hf space
59
  text = sentence_to_stress(text) if stress == StressOption.AutomaticStress.value else text
60
  print(text, datetime.utcnow())
61
- if synthesizer is None:
62
- raise NameError("model not found")
63
- wavs = synthesizer.tts(text)
64
- # output = (synthesizer.output_sample_rate, np.array(wavs))
65
- # return output
66
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
67
- synthesizer.save_wav(wavs, fp)
 
 
 
 
 
 
 
 
68
  return fp.name
69
 
70
 
 
10
  from datetime import datetime
11
  from stress import sentence_to_stress
12
  from enum import Enum
13
+ import torch
14
+ import gc
15
 
16
  class StressOption(Enum):
17
  ManualStress = "Наголоси вручну"
 
52
 
53
 
54
  def tts(text: str, stress: str):
 
 
 
55
  text = preprocess_text(text)
56
+ text_limit = 1200
57
  text = text if len(text) < text_limit else text[0:text_limit] # mitigate crashes on hf space
58
  text = sentence_to_stress(text) if stress == StressOption.AutomaticStress.value else text
59
  print(text, datetime.utcnow())
60
+
 
 
 
 
61
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
62
+ with torch.no_grad():
63
+ synthesizer = Synthesizer(
64
+ model_path, config_path, None, None, None,
65
+ )
66
+ if synthesizer is None:
67
+ raise NameError("model not found")
68
+ wavs = synthesizer.tts(text)
69
+ synthesizer.save_wav(wavs, fp)
70
+ gc.collect()
71
  return fp.name
72
 
73