sagar007 committed
Commit
6b1a045
1 Parent(s): 1f0b302

Update app.py

Files changed (1)
  1. app.py +7 -5
app.py CHANGED
@@ -79,7 +79,6 @@ async def generate_speech(text, tts_model, tts_tokenizer):
 
     return audio_generation.cpu().numpy().squeeze()
 
-# Helper functions
 @spaces.GPU(timeout=300)
 def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
     try:
@@ -111,22 +110,24 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
         thread.start()
 
         buffer = ""
-        audio_buffer = np.array([])
+        audio_buffer = np.array([0.0])  # Initialize with a single zero
 
         for new_text in streamer:
             buffer += new_text
             yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
 
         # Generate speech after text generation is complete
-        if use_tts:
+        if use_tts and buffer:  # Only generate speech if there's text
             audio_buffer = generate_speech_sync(buffer, tts_model, tts_tokenizer)
+            if audio_buffer.size == 0:  # If audio_buffer is empty
+                audio_buffer = np.array([0.0])  # Use a single zero instead
 
         # Final yield with complete text and audio
         yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
 
     except Exception as e:
         print(f"An error occurred: {str(e)}")
-        yield history + [[message, f"An error occurred: {str(e)}"]], None
+        yield history + [[message, f"An error occurred: {str(e)}"]], (tts_model.config.sampling_rate, np.array([0.0]))
 
 def generate_speech_sync(text, tts_model, tts_tokenizer):
     tts_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
@@ -136,7 +137,8 @@ def generate_speech_sync(text, tts_model, tts_tokenizer):
     with torch.no_grad():
         audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
 
-    return audio_generation.cpu().numpy().squeeze()
+    audio_buffer = audio_generation.cpu().numpy().squeeze()
+    return audio_buffer if audio_buffer.size > 0 else np.array([0.0])
 
 @spaces.GPU(timeout=300) # Increase timeout to 5 minutes
 def process_vision_query(image, text_input):
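
Taken together, the changes guard against handing Gradio an empty waveform: the audio buffer now starts as a single zero sample, speech is only synthesized when there is text, and generate_speech_sync falls back to np.array([0.0]) if the TTS output is empty. The sketch below illustrates that pattern in isolation; the component wiring, the 16 kHz sampling rate, and the demo_stream/safe_audio helpers are assumptions for the example, not code from app.py.

import gradio as gr
import numpy as np

def safe_audio(waveform: np.ndarray) -> np.ndarray:
    # Mirror the commit's guard: never hand gr.Audio an empty array.
    return waveform if waveform.size > 0 else np.array([0.0])

def demo_stream(message, history):
    # Illustrative stand-in for stream_text_chat: stream text word by word,
    # then attach audio once the text is complete.
    history = history or []
    sampling_rate = 16000                      # assumed rate for this sketch
    audio_buffer = np.array([0.0])             # start with a single zero sample
    buffer = ""
    for word in message.split():               # stand-in for the token streamer
        buffer += word + " "
        yield history + [[message, buffer]], (sampling_rate, audio_buffer)
    # Stand-in for generate_speech_sync: half a second of silence.
    audio_buffer = safe_audio(np.zeros(sampling_rate // 2, dtype=np.float32))
    yield history + [[message, buffer]], (sampling_rate, audio_buffer)

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    audio = gr.Audio()
    msg = gr.Textbox(label="Message")
    msg.submit(demo_stream, inputs=[msg, chatbot], outputs=[chatbot, audio])

if __name__ == "__main__":
    demo.launch()

Yielding inside the loop streams partial text while the placeholder silence keeps the Audio output valid; the final yield swaps in the full (here, silent) waveform once generation finishes, matching the shape of the tuples yielded by stream_text_chat above.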