hubert-large-asr

This model is a fine-tuned version of rinna/japanese-hubert-large for automatic speech recognition (ASR). It was first fine-tuned on the ReazonSpeech (small) dataset and then further fine-tuned on the common_voice_11_0 dataset.

This model can only predict Hiragana.
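
The output vocabulary is limited to hiragana, so you can sanity-check transcripts (and references normalized to hiragana) with a simple Unicode range test. The sketch below is illustrative only. It assumes the valid character set is the Hiragana block U+3040 to U+309F plus ASCII spaces, and the helper name is_hiragana_only is hypothetical. This card does not say whether the model's vocabulary also contains symbols outside that block, such as the prolonged sound mark ー.

```python
# Assumed valid range: the Unicode Hiragana block (U+3040..U+309F).
HIRAGANA_START, HIRAGANA_END = 0x3040, 0x309F


def is_hiragana_only(text: str, allow_spaces: bool = True) -> bool:
    """Return True if every character of `text` lies in the Hiragana block.

    Spaces are optionally accepted because word-segmented text (e.g. after
    MeCab wakati-gaki) contains them. Anything else, including katakana and
    the prolonged sound mark "ー" (U+30FC), is rejected by this sketch.
    """
    for ch in text:
        if allow_spaces and ch == " ":
            continue
        if not HIRAGANA_START <= ord(ch) <= HIRAGANA_END:
            return False
    return True


print(is_hiragana_only("きょうは はれ"))  # True
print(is_hiragana_only("今日は"))        # False
```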

Acknowledgments

The fine-tuning approach was inspired by, and follows the training methodology of, vumichien/wav2vec2-large-xlsr-japanese-hiragana.

Training procedure

The model was fine-tuned in two stages: first on the ReazonSpeech dataset, then on the common_voice_11_0 dataset. The training steps and results of each stage are as follows:

Training on ReazonSpeech

The initial fine-tuning on the ReazonSpeech (small) dataset produced the following metrics:

Step    Training Loss   Validation Loss   WER
1000    12.29880        3.610288          1.00000
2000    3.601800        3.505306          1.00000
3000    2.80300         1.948012          0.722361
4000    1.961500        1.545842          0.558738
5000    1.712000        1.420027          0.509049
6000    1.565500        1.235171          0.466279
7000    1.504900        1.160565          0.461829
8000    1.409800        1.088012          0.427435
9000    1.358800        1.097211          0.409861
10000   1.318600        1.062294          0.403694
11000   1.258500        1.026783          0.385464
12000   1.245100        1.024860          0.379845
13000   1.217700        0.985201          0.375634
14000   1.187900        0.977686          0.367163
15000   1.168100        0.978529          0.363656
16000   1.135800        0.965668          0.363942
17000   1.140600        0.953237          0.360912

Training on common_voice_11_0

After the ReazonSpeech stage, the model was further fine-tuned on the common_voice_11_0 dataset, with the following results:

Step    Training Loss   Validation Loss   WER
1000    1.08950         0.49275           0.302035
2000    0.86100         0.45113           0.266950
3000    0.76240         0.442281          0.244981
4000    0.70170         0.411666          0.234287
5000    0.66400         0.411769          0.227942
6000    0.63810         0.413067          0.225690

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 1e-4
  • train_batch_size: 8
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 2
  • num_train_epochs: 10
  • lr_scheduler_type: linear
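
These hyperparameters presumably correspond to keyword arguments of transformers.TrainingArguments (learning_rate, per_device_train_batch_size, per_device_eval_batch_size, seed, gradient_accumulation_steps, num_train_epochs, lr_scheduler_type). The card states neither that mapping nor the number of GPUs, so both are assumptions. With gradient accumulation, each optimizer update sees 8 × 2 = 16 examples per device.

The pure-Python sketch below shows that arithmetic and a linear learning-rate schedule. The schedule warms up linearly and then decays linearly to zero, which is the shape lr_scheduler_type "linear" selects in transformers. No warmup is listed, so warmup_steps defaults to 0 here.

```python
# Hyperparameters from the list above, keyed by transformers.TrainingArguments
# names (assumed mapping: train_batch_size -> per_device_train_batch_size).
HPARAMS = {
    "learning_rate": 1e-4,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "seed": 42,
    "gradient_accumulation_steps": 2,
    "num_train_epochs": 10,
    "lr_scheduler_type": "linear",
}


def effective_batch_size(hp: dict, num_devices: int = 1) -> int:
    """Examples consumed per optimizer update (num_devices is hypothetical)."""
    return hp["per_device_train_batch_size"] * hp["gradient_accumulation_steps"] * num_devices


def linear_lr(step: int, total_steps: int, base_lr: float, warmup_steps: int = 0) -> float:
    """Learning rate at `step`: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


print(effective_batch_size(HPARAMS))  # 16
print(linear_lr(500, 1000, 1e-4))     # 5e-05
```

Here total_steps=1000 is only an illustrative value. The real number of steps depends on dataset size and is not given in this card.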

How to evaluate the model

from transformers import HubertForCTC, Wav2Vec2Processor
from datasets import load_dataset
import torch
import torchaudio
import librosa
import numpy as np
import re
import MeCab
import pykakasi
from evaluate import load

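model = HubertForCTC.from_pretrained('TKU410410103/hubert-large-japanese-asr')
processor = Wav2Vec2Processor.from_pretrained("TKU410410103/hubert-large-japanese-asr")

# run on a GPU when available; evaluate() below moves its inputs to `device`
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()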

# load dataset
test_dataset = load_dataset('mozilla-foundation/common_voice_11_0', 'ja', split='test')
remove_columns = [col for col in test_dataset.column_names if col not in ['audio', 'sentence']]
test_dataset = test_dataset.remove_columns(remove_columns)

# resample
def process_waveforms(batch):
    speech_arrays = []
    sampling_rates = []

    for audio in batch['audio']:
        # load at the file's native sampling rate (Common Voice clips are usually 48 kHz)
        speech_array, orig_sr = torchaudio.load(audio['path'])
        # resample the first channel to the 16 kHz rate the model expects
        speech_array_resampled = librosa.resample(speech_array[0].numpy(), orig_sr=orig_sr, target_sr=16000)
        speech_arrays.append(speech_array_resampled)
        sampling_rates.append(16000)

    batch["array"] = speech_arrays
    batch["sampling_rate"] = sampling_rates

    return batch

# hiragana
CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", "；", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
          "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
          "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
          "、", "﹂", "﹁", "‧", "～", "﹏", "，", "｛", "｝", "（", "）", "［", "］", "【", "】", "‥", "〽",
          "『", "』", "〝", "〟", "⟨", "⟩", "〜", "：", "！", "？", "♪", "؛", "/", "\\", "º", "−", "^", "'", "ʻ", "ˆ"]
chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"

wakati = MeCab.Tagger("-Owakati")
kakasi = pykakasi.kakasi()
kakasi.setMode("J","H")
kakasi.setMode("K","H")
kakasi.setMode("r","Hepburn")
conv = kakasi.getConverter()

def prepare_char(batch):
    batch["sentence"] = conv.do(wakati.parse(batch["sentence"]).strip())
    batch["sentence"] = re.sub(chars_to_ignore_regex,'', batch["sentence"]).strip()
    return batch


resampled_eval_dataset = test_dataset.map(process_waveforms, batched=True, batch_size=50, num_proc=4)
eval_dataset = resampled_eval_dataset.map(prepare_char, num_proc=4)

# begin the evaluation process
wer = load("wer")
cer = load("cer")

def evaluate(batch):
    inputs = processor(batch["array"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

columns_to_remove = [column for column in eval_dataset.column_names if column != "sentence"]
batch_size = 16
result = eval_dataset.map(evaluate, remove_columns=columns_to_remove, batched=True, batch_size=batch_size)

wer_result = wer.compute(predictions=result["pred_strings"], references=result["sentence"])
cer_result = cer.compute(predictions=result["pred_strings"], references=result["sentence"])

print("WER: {:2f}%".format(100 * wer_result))
print("CER: {:2f}%".format(100 * cer_result))
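
In the script above, torch.argmax over the logits followed by processor.batch_decode amounts to greedy (best-path) CTC decoding. For each frame the script takes the most likely token, then merges consecutive duplicates, drops the CTC blank, and renders the word delimiter as a space.

The pure-Python sketch below shows that collapse rule using a hypothetical toy vocabulary. In it, id 0 is the blank (Wav2Vec2-style CTC tokenizers use the pad token as the blank) and "|" is the word delimiter. This card does not list the real tokenizer's ids or special tokens.

```python
from itertools import groupby
from typing import Dict, List

# Hypothetical toy vocabulary; the real tokenizer's ids will differ.
VOCAB: Dict[int, str] = {0: "<pad>", 1: "|", 2: "き", 3: "ょ", 4: "う", 5: "は"}
BLANK_ID = 0          # CTC blank (the pad token in Wav2Vec2-style tokenizers)
WORD_DELIMITER = "|"  # rendered as a space in the decoded text


def greedy_ctc_decode(frame_ids: List[int]) -> str:
    """Collapse per-frame argmax ids into text (best-path CTC decoding)."""
    # 1. merge runs of identical consecutive ids
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. drop blanks
    tokens = [VOCAB[t] for t in collapsed if t != BLANK_ID]
    # 3. turn word delimiters into spaces and trim the ends
    return "".join(" " if t == WORD_DELIMITER else t for t in tokens).strip()


print(greedy_ctc_decode([2, 2, 0, 3, 4, 4, 1, 1, 5, 0, 0]))  # きょう は
```

A blank between two identical ids is what lets the decoder emit a doubled kana such as はは. Without the blank, the repeat is merged into a single は.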

Test results

The final model was evaluated as follows:

On ReazonSpeech (tiny):

  • WER: 40.519700%
  • CER: 23.220979%

On common_voice_11_0:

  • WER: 22.705487%
  • CER: 9.399390%
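
WER and CER are edit-distance ratios. Each counts the substitutions, deletions and insertions needed to turn a hypothesis into its reference, divided by the reference length. WER counts over whitespace-separated words (for the references, the MeCab-segmented units produced by prepare_char), and CER counts over characters.

The pure-Python sketch below computes corpus-level scores in the spirit of what the evaluate "wer" and "cer" metrics report. Those metrics are backed by jiwer, whose normalization details (for example, whether spaces count toward CER) may differ from this sketch.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance: minimum substitutions + deletions + insertions."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of a reference token
                curr[j - 1] + 1,         # insertion of a hypothesis token
                prev[j - 1] + (r != h),  # substitution (cost 0 on a match)
            )
        prev = curr
    return prev[-1]


def corpus_wer(refs: List[str], hyps: List[str]) -> float:
    """Total word-level errors divided by total reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    return errors / sum(len(r.split()) for r in refs)


def corpus_cer(refs: List[str], hyps: List[str]) -> float:
    """Total character-level errors divided by total reference characters."""
    errors = sum(edit_distance(list(r), list(h)) for r, h in zip(refs, hyps))
    return errors / sum(len(r) for r in refs)
```

For example, one wrong word out of three ("きょう は はれ" vs. "きょう は あめ") gives a WER of 1/3.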

Framework versions

  • Transformers 4.39.1
  • Pytorch 2.2.1+cu118
  • Datasets 2.17.1

Model size

  • 316M params (Safetensors, tensor type F32)