aiola commited on
Commit
2d4fde9
·
verified ·
1 Parent(s): f6ac33a

Update app with tag and masking model

Browse files
Files changed (1) hide show
  1. app.py +44 -33
app.py CHANGED
@@ -9,39 +9,40 @@ import re
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # Load model and processor
12
- processor = WhisperProcessor.from_pretrained("aiola/whisper-ner-v1")
13
- model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1")
14
  model = model.to(device)
15
 
16
-
17
  examples = [
18
  [
19
  "audio/sports.wav",
20
- "football-club, football-player, action"
 
21
  ],
22
  [
23
  "audio/entertainment.wav",
24
- "movie, date, actor, tv-show, musician"
 
25
  ],
26
  [
27
  "audio/672-122797-0026.wav",
28
- "biological-classification, desire, demographic-group, object-category, relationship-role, reflexive-pronoun, furniture-type"
 
29
  ],
30
- [
31
- "audio/7021-85628-0025.wav",
32
- "action-goal, person's-title, emotional-connection, personal-qualities, pronoun-target, assignmentaction, physical-action, family-role"
 
33
  ],
34
  [
35
  "audio/672-122797-0024.wav",
36
- "health-warning, importance-indicator, event, sentiment"
37
- ],
38
- [
39
- "audio/672-122797-0027.wav",
40
- "action, emotional-resilience, comparative-path-characteristic, social-role"
41
  ],
42
  [
43
  "audio/672-122797-0048.wav",
44
- "weapon, emotional-state, household-chore, atmosphere-quality"
 
45
  ],
46
  ]
47
 
@@ -54,8 +55,8 @@ def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
54
  return text.lower()
55
 
56
 
57
- def extract_entities_and_clean_text_fixed(text):
58
- entity_pattern = r"<(.*?)>(.*?)<\1>>"
59
  entities = []
60
  clean_text = []
61
  current_pos = 0
@@ -66,7 +67,7 @@ def extract_entities_and_clean_text_fixed(text):
66
  clean_text.append(text[current_pos:match.start()])
67
 
68
  entity_type = match.group(1)
69
- entity_text = match.group(2)
70
  start_pos = len("".join(clean_text)) # Start position in the clean text
71
  end_pos = start_pos + len(entity_text)
72
 
@@ -94,7 +95,7 @@ def extract_entities_and_clean_text_fixed(text):
94
 
95
 
96
  @spaces.GPU # This decorator ensures your function can use GPU on Hugging Face Spaces
97
- def transcribe_and_recognize_entities(audio_file, prompt):
98
  target_sample_rate = 16000
99
  signal, sampling_rate = torchaudio.load(audio_file)
100
  resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sample_rate)
@@ -108,6 +109,8 @@ def transcribe_and_recognize_entities(audio_file, prompt):
108
  ner_types = prompt.split(',')
109
  processed_ner_types = [unify_ner_text(ner_type.strip()) for ner_type in ner_types]
110
  prompt = ", ".join(processed_ner_types)
 
 
111
 
112
  print(f"Prompt after unify_ner_text: {prompt}")
113
  prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
@@ -122,36 +125,44 @@ def transcribe_and_recognize_entities(audio_file, prompt):
122
  )
123
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
124
 
125
- clean_text_fixed, extracted_entities_fixed = extract_entities_and_clean_text_fixed(transcription)
126
 
127
  return transcription, {"text": clean_text_fixed, "entities": extracted_entities_fixed}
128
 
129
 
130
  with gr.Blocks(title="WhisperNER v1") as demo:
131
-
132
  gr.Markdown(
133
  """
134
- # Whisper-NER: ASR with zero-shot NER
135
-
136
  WhisperNER is a unified model for automatic speech recognition (ASR) and named entity recognition (NER), with zero-shot capabilities.
137
  The WhisperNER model is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.
 
 
 
 
 
138
 
139
  ## Links
140
-
141
- * Paper: [WhisperNER: Unified Open Named Entity and Speech Recognition](https://arxiv.org/abs/2409.08107).
142
- * Model: https://huggingface.co/aiola/whisper-ner-v1
143
- * Code: https://github.com/aiola-lab/whisper-ner
144
  """
145
  )
146
 
147
  with gr.Row() as row1:
148
  with gr.Column() as col1:
149
- audio_input = gr.Audio(label="Audio Example", type="filepath")
150
  with gr.Column() as col2:
151
- label_input = gr.Textbox(label="Entity Labels")
 
 
 
 
 
 
152
 
153
  submit_btn = gr.Button("Submit")
154
-
155
  gr.Markdown("## Output")
156
 
157
  with gr.Row() as row3:
@@ -163,7 +174,7 @@ with gr.Blocks(title="WhisperNER v1") as demo:
163
  examples = gr.Examples(
164
  examples,
165
  fn=transcribe_and_recognize_entities,
166
- inputs=[audio_input, label_input],
167
  outputs=[transcript_output, highlighted_text_output],
168
  cache_examples=True,
169
  run_on_click=True,
@@ -172,12 +183,12 @@ with gr.Blocks(title="WhisperNER v1") as demo:
172
  # Submitting
173
  label_input.submit(
174
  fn=transcribe_and_recognize_entities,
175
- inputs=[audio_input, label_input],
176
  outputs=[transcript_output, highlighted_text_output],
177
  )
178
  submit_btn.click(
179
  fn=transcribe_and_recognize_entities,
180
- inputs=[audio_input, label_input],
181
  outputs=[transcript_output, highlighted_text_output],
182
  )
183
 
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # Load model and processor
12
+ processor = WhisperProcessor.from_pretrained("aiola/whisper-ner-tag-and-mask-v1")
13
+ model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-tag-and-mask-v1")
14
  model = model.to(device)
15
 
 
16
  examples = [
17
  [
18
  "audio/sports.wav",
19
+ "football-club, football-player, referee",
20
+ False
21
  ],
22
  [
23
  "audio/entertainment.wav",
24
+ "movie, date, actor, tv-show, musician",
25
+ True
26
  ],
27
  [
28
  "audio/672-122797-0026.wav",
29
+ "biological-classification, desire, demographic-group, object-category, relationship-role, reflexive-pronoun, furniture-type",
30
+ False
31
  ],
32
+ [
33
+ "audio/672-122797-0027.wav",
34
+ "action, emotional-resilience, comparative-path-characteristic, social-role",
35
+ True
36
  ],
37
  [
38
  "audio/672-122797-0024.wav",
39
+ "health-warning, importance-indicator, event, sentiment",
40
+ False
 
 
 
41
  ],
42
  [
43
  "audio/672-122797-0048.wav",
44
+ "weapon, emotional-state, household-chore, atmosphere-quality",
45
+ False
46
  ],
47
  ]
48
 
 
55
  return text.lower()
56
 
57
 
58
+ def extract_entities_and_clean_text_fixed(text, ner_mask=False):
59
+ entity_pattern = r"<(.*?)>(.*?)<\1>>" if not ner_mask else r"<(.*?)>>"
60
  entities = []
61
  clean_text = []
62
  current_pos = 0
 
67
  clean_text.append(text[current_pos:match.start()])
68
 
69
  entity_type = match.group(1)
70
+ entity_text = "-" if ner_mask else match.group(2)
71
  start_pos = len("".join(clean_text)) # Start position in the clean text
72
  end_pos = start_pos + len(entity_text)
73
 
 
95
 
96
 
97
  @spaces.GPU # This decorator ensures your function can use GPU on Hugging Face Spaces
98
+ def transcribe_and_recognize_entities(audio_file, prompt, ner_mask=False):
99
  target_sample_rate = 16000
100
  signal, sampling_rate = torchaudio.load(audio_file)
101
  resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sample_rate)
 
109
  ner_types = prompt.split(',')
110
  processed_ner_types = [unify_ner_text(ner_type.strip()) for ner_type in ner_types]
111
  prompt = ", ".join(processed_ner_types)
112
+ if ner_mask:
113
+ prompt = f"<|mask|>{prompt}"
114
 
115
  print(f"Prompt after unify_ner_text: {prompt}")
116
  prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
 
125
  )
126
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
127
 
128
+ clean_text_fixed, extracted_entities_fixed = extract_entities_and_clean_text_fixed(transcription, ner_mask=ner_mask)
129
 
130
  return transcription, {"text": clean_text_fixed, "entities": extracted_entities_fixed}
131
 
132
 
133
  with gr.Blocks(title="WhisperNER v1") as demo:
 
134
  gr.Markdown(
135
  """
136
+ # 🔥 Whisper-NER: ASR with zero-shot NER
 
137
  WhisperNER is a unified model for automatic speech recognition (ASR) and named entity recognition (NER), with zero-shot capabilities.
138
  The WhisperNER model is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.
139
+
140
+ The [aiola/whisper-ner-tag-and-mask-v1](https://huggingface.co/aiola/whisper-ner-tag-and-mask-v1) model was finetuned from
141
+ the [aiola/whisper-ner-v1](https://huggingface.co/aiola/whisper-ner-v1) checkpoint using the NuNER dataset to perform joint audio transcription and NER tagging or NER masking.
142
+ The model was not trained on PII specific datasets, hence can perform general and open type entity masking.
143
+ It should be further funetuned in order to be used for PII detection. The model was trained and evaluated only on English data. Check out the paper for full details.
144
 
145
  ## Links
146
+ * 📄 Paper: [WhisperNER: Unified Open Named Entity and Speech Recognition](https://arxiv.org/abs/2409.08107)
147
+ * 🤗 [WhisperNER model collection](https://huggingface.co/collections/aiola/whisperner-6723f14506f3662cf3a73df2)
148
+ * 💻 Code: https://github.com/aiola-lab/whisper-ner
 
149
  """
150
  )
151
 
152
  with gr.Row() as row1:
153
  with gr.Column() as col1:
154
+ audio_input = gr.Audio(value=examples[0][0], label="Audio Example", type="filepath")
155
  with gr.Column() as col2:
156
+ label_input = gr.Textbox(label="Entity Labels", value=examples[0][1])
157
+ ner_mask = gr.Checkbox(
158
+ value=examples[0][2],
159
+ label="Entity Mask",
160
+ info="Mask or tag entities in the transcription.",
161
+ scale=0,
162
+ )
163
 
164
  submit_btn = gr.Button("Submit")
165
+
166
  gr.Markdown("## Output")
167
 
168
  with gr.Row() as row3:
 
174
  examples = gr.Examples(
175
  examples,
176
  fn=transcribe_and_recognize_entities,
177
+ inputs=[audio_input, label_input, ner_mask],
178
  outputs=[transcript_output, highlighted_text_output],
179
  cache_examples=True,
180
  run_on_click=True,
 
183
  # Submitting
184
  label_input.submit(
185
  fn=transcribe_and_recognize_entities,
186
+ inputs=[audio_input, label_input, ner_mask],
187
  outputs=[transcript_output, highlighted_text_output],
188
  )
189
  submit_btn.click(
190
  fn=transcribe_and_recognize_entities,
191
+ inputs=[audio_input, label_input, ner_mask],
192
  outputs=[transcript_output, highlighted_text_output],
193
  )
194