anuragshas commited on
Commit
dc167bf
·
1 Parent(s): 37475e4

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +49 -15
eval.py CHANGED
@@ -4,6 +4,7 @@ import re
4
  import unicodedata
5
  from typing import Dict
6
 
 
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
  from transformers import AutoFeatureExtractor, pipeline
@@ -20,8 +21,12 @@ def log_results(result: Dataset, args: Dict[str, str]):
20
  cer = load_metric("cer")
21
 
22
  # compute metrics
23
- wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
24
- cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
 
 
 
 
25
 
26
  # print & log results
27
  result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
@@ -50,10 +55,9 @@ def log_results(result: Dataset, args: Dict[str, str]):
50
  def normalize_text(text: str) -> str:
51
  """DO ADAPT FOR YOUR USE CASE. this function normalizes the target text."""
52
 
53
- chars_to_ignore_regex = '''[\।\!\"\,\-\.\?\:\|\“\”\–\;\'\’\‘\॥]''' # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
54
-
55
- text = re.sub(chars_to_ignore_regex, "", text.lower())
56
  text = unicodedata.normalize("NFKC", text)
 
57
 
58
  # In addition, we can normalize the target text, e.g. removing new lines characters etc...
59
  # note that order is important here!
@@ -67,7 +71,9 @@ def normalize_text(text: str) -> str:
67
 
68
  def main(args):
69
  # load dataset
70
- dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
 
 
71
 
72
  # for testing: only process the first two examples as a test
73
  # dataset = dataset.select(range(10))
@@ -80,12 +86,18 @@ def main(args):
80
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
81
 
82
  # load eval pipeline
83
- asr = pipeline("automatic-speech-recognition", model=args.model_id, device=0)
 
 
 
 
84
 
85
  # map function to decode audio
86
  def map_to_pred(batch):
87
  prediction = asr(
88
- batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
 
 
89
  )
90
 
91
  batch["prediction"] = prediction["text"]
@@ -104,7 +116,10 @@ if __name__ == "__main__":
104
  parser = argparse.ArgumentParser()
105
 
106
  parser.add_argument(
107
- "--model_id", type=str, required=True, help="Model identifier. Should be loadable with πŸ€— Transformers"
 
 
 
108
  )
109
  parser.add_argument(
110
  "--dataset",
@@ -113,18 +128,37 @@ if __name__ == "__main__":
113
  help="Dataset name to evaluate the `model_id`. Should be loadable with πŸ€— Datasets",
114
  )
115
  parser.add_argument(
116
- "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
118
- parser.add_argument("--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`")
119
  parser.add_argument(
120
- "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to 5 seconds."
 
 
 
121
  )
122
  parser.add_argument(
123
- "--stride_length_s", type=float, default=None, help="Stride of the audio chunks. Defaults to 1 second."
 
 
124
  )
125
  parser.add_argument(
126
- "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
 
 
 
127
  )
128
  args = parser.parse_args()
129
 
130
- main(args)
 
4
  import unicodedata
5
  from typing import Dict
6
 
7
+ import torch
8
  from datasets import Audio, Dataset, load_dataset, load_metric
9
 
10
  from transformers import AutoFeatureExtractor, pipeline
 
21
  cer = load_metric("cer")
22
 
23
  # compute metrics
24
+ wer_result = wer.compute(
25
+ references=result["target"], predictions=result["prediction"]
26
+ )
27
+ cer_result = cer.compute(
28
+ references=result["target"], predictions=result["prediction"]
29
+ )
30
 
31
  # print & log results
32
  result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
 
55
  def normalize_text(text: str) -> str:
56
  """DO ADAPT FOR YOUR USE CASE. this function normalizes the target text."""
57
 
58
+ chars_to_ignore_regex = """[\,\?\.\!\-\;\:\"\β€œ\%\β€˜\”\οΏ½\β€”\’\…\–\'\ΰ₯€\ΰ₯”]""" # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
 
 
59
  text = unicodedata.normalize("NFKC", text)
60
+ text = re.sub(chars_to_ignore_regex, "", text.lower())
61
 
62
  # In addition, we can normalize the target text, e.g. removing new lines characters etc...
63
  # note that order is important here!
 
71
 
72
  def main(args):
73
  # load dataset
74
+ dataset = load_dataset(
75
+ args.dataset, args.config, split=args.split, use_auth_token=True
76
+ )
77
 
78
  # for testing: only process the first two examples as a test
79
  # dataset = dataset.select(range(10))
 
86
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
87
 
88
  # load eval pipeline
89
+ if args.device is None:
90
+ args.device = 0 if torch.cuda.is_available() else -1
91
+ asr = pipeline(
92
+ "automatic-speech-recognition", model=args.model_id, device=args.device
93
+ )
94
 
95
  # map function to decode audio
96
  def map_to_pred(batch):
97
  prediction = asr(
98
+ batch["audio"]["array"],
99
+ chunk_length_s=args.chunk_length_s,
100
+ stride_length_s=args.stride_length_s,
101
  )
102
 
103
  batch["prediction"] = prediction["text"]
 
116
  parser = argparse.ArgumentParser()
117
 
118
  parser.add_argument(
119
+ "--model_id",
120
+ type=str,
121
+ required=True,
122
+ help="Model identifier. Should be loadable with πŸ€— Transformers",
123
  )
124
  parser.add_argument(
125
  "--dataset",
 
128
  help="Dataset name to evaluate the `model_id`. Should be loadable with πŸ€— Datasets",
129
  )
130
  parser.add_argument(
131
+ "--config",
132
+ type=str,
133
+ required=True,
134
+ help="Config of the dataset. *E.g.* `'en'` for Common Voice",
135
+ )
136
+ parser.add_argument(
137
+ "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
138
+ )
139
+ parser.add_argument(
140
+ "--chunk_length_s",
141
+ type=float,
142
+ default=None,
143
+ help="Chunk length in seconds. Defaults to 5 seconds.",
144
  )
 
145
  parser.add_argument(
146
+ "--stride_length_s",
147
+ type=float,
148
+ default=None,
149
+ help="Stride of the audio chunks. Defaults to 1 second.",
150
  )
151
  parser.add_argument(
152
+ "--log_outputs",
153
+ action="store_true",
154
+ help="If defined, write outputs to log file for analysis.",
155
  )
156
  parser.add_argument(
157
+ "--device",
158
+ type=int,
159
+ default=None,
160
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
161
  )
162
  args = parser.parse_args()
163
 
164
+ main(args)