sanchit-gandhi commited on
Commit
c0bc0f2
·
1 Parent(s): 0161d23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -10,13 +10,18 @@ import gradio as gr
10
 
11
  model_id = "openai/whisper-large-v2"
12
 
 
 
13
  processor = WhisperProcessor.from_pretrained(model_id)
14
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
 
 
15
 
16
  sampling_rate = processor.feature_extractor.sampling_rate
17
 
18
  bos_token_id = processor.tokenizer.all_special_ids[-106]
19
  decoder_input_ids = torch.tensor([bos_token_id])
 
20
 
21
 
22
  def process_audio_file(file):
@@ -47,7 +52,7 @@ def transcribe(Microphone, File_Upload):
47
  input_features = processor(audio_data, return_tensors="pt").input_features
48
 
49
  with torch.no_grad():
50
- logits = model.forward(input_features, decoder_input_ids=decoder_input_ids).logits
51
 
52
  pred_ids = torch.argmax(logits, dim=-1)
53
  probability = F.softmax(logits, dim=-1).max()
 
10
 
11
  model_id = "openai/whisper-large-v2"
12
 
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
  processor = WhisperProcessor.from_pretrained(model_id)
16
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
17
+ model.eval()
18
+ model.to(device)
19
 
20
  sampling_rate = processor.feature_extractor.sampling_rate
21
 
22
  bos_token_id = processor.tokenizer.all_special_ids[-106]
23
  decoder_input_ids = torch.tensor([bos_token_id])
24
+ decoder_input_ids.to(device)
25
 
26
 
27
  def process_audio_file(file):
 
52
  input_features = processor(audio_data, return_tensors="pt").input_features
53
 
54
  with torch.no_grad():
55
+ logits = model.forward(input_features.to(device), decoder_input_ids=decoder_input_ids).logits
56
 
57
  pred_ids = torch.argmax(logits, dim=-1)
58
  probability = F.softmax(logits, dim=-1).max()