Spaces:
Runtime error
Runtime error
import json | |
import os | |
import warnings | |
import gradio as gr | |
import librosa | |
import numpy as np | |
from datasets import IterableDatasetDict, load_dataset | |
from gradio_client import Client | |
from loguru import logger | |
# Silence every warning category for the whole process (applies to all
# libraries imported above as well).
warnings.filterwarnings(action="ignore")

# Number of webdataset shards in the training split, and the Hub repository
# that hosts them.
NUM_TAR_FILES = 115
HF_PATH_TO_DATASET = "litagin/Galgame_Speech_SER_16kHz"

# Client for the companion Space that receives label submissions through its
# "/put_data" endpoint. The token comes from the environment and is None when
# HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
client = Client("litagin/ser_record", hf_token=hf_token)
# Numeric class id (0-9) -> English emotion name.
id2label = dict(
    enumerate(
        (
            "Angry",
            "Disgusted",
            "Embarrassed",
            "Fearful",
            "Happy",
            "Sad",
            "Surprised",
            "Neutral",
            "Sexual1",
            "Sexual2",
        )
    )
)
id2rich_label = { | |
0: "😠 怒り (0)", | |
1: "😒 嫌悪 (1)", | |
2: "😳 恥ずかしさ・戸惑い (2)", | |
3: "😨 恐怖 (3)", | |
4: "😊 幸せ (4)", | |
5: "😢 悲しみ (5)", | |
6: "😲 驚き (6)", | |
7: "😐 中立 (7)", | |
8: "🥰 NSFW1 (8)", | |
9: "🍭 NSFW2 (9)", | |
} | |
current_item: dict | None = None | |
def _load_dataset(
    *,
    streaming: bool = True,
    use_local_dataset: bool = False,
    local_dataset_path: str | None = None,
    data_dir: str = "data",
) -> IterableDatasetDict:
    """Load the Galgame SER webdataset shards as a single ``train`` split.

    Args:
        streaming: Forwarded to ``datasets.load_dataset``. With ``True`` the
            result is an ``IterableDatasetDict``; with ``False`` ``load_dataset``
            returns a regular (non-streaming) dataset dict despite the
            annotation.
        use_local_dataset: Read the shards from ``local_dataset_path`` instead
            of the Hub repository ``HF_PATH_TO_DATASET``.
        local_dataset_path: Local dataset root; required when
            ``use_local_dataset`` is true.
        data_dir: Forwarded to ``load_dataset`` (presumably the directory the
            shard names below are resolved against — TODO confirm).

    Returns:
        The dataset dict with the ``__url__`` column removed and the ``ogg``
        column renamed to ``audio``.

    Raises:
        ValueError: If ``use_local_dataset`` is true but no path was given.
    """
    # Validate up front with a real exception: ``assert`` is stripped when
    # Python runs with -O, which would let ``None`` reach load_dataset.
    if use_local_dataset:
        if local_dataset_path is None:
            raise ValueError(
                "local_dataset_path is required when use_local_dataset=True"
            )
        path = local_dataset_path
    else:
        path = HF_PATH_TO_DATASET

    # Shard names are zero-padded to six digits:
    # galgame-speech-ser-16kHz-train-000000.tar ... -000114.tar
    data_files = {
        "train": [
            f"galgame-speech-ser-16kHz-train-{index:06d}.tar"
            for index in range(NUM_TAR_FILES)
        ],
    }
    dataset: IterableDatasetDict = load_dataset(
        path=path, data_dir=data_dir, data_files=data_files, streaming=streaming
    )  # type: ignore
    dataset = dataset.remove_columns(["__url__"])
    dataset = dataset.rename_column("ogg", "audio")
    return dataset
# Open the dataset as a stream once, at import time. The single iterator
# created below is a module global, so all handlers (and all sessions) draw
# successive samples from the same stream.
logger.info("Start loading dataset")
ds = _load_dataset(use_local_dataset=False, streaming=True)
logger.info("Dataset loaded")

# Samples are served in shard order. For a randomized order, shuffle the split
# before iterating, e.g. iter(ds["train"].shuffle(seed=some_seed)).
ds_iter = iter(ds["train"])
# Keyboard shortcuts injected into the page <head>: Enter clicks the
# "label is fine" button (elem_id "btn_skip"), and the digit keys 0-9 click the
# matching label button (elem_id "btn_0" ... "btn_9").
shortcut_js = """
<script>
function shortcuts(e) {
    // Ignore auto-repeat (holding a key would submit several items), modifier
    // chords such as Ctrl+1, and key presses that confirm IME conversion.
    if (e.repeat || e.isComposing || e.ctrlKey || e.metaKey || e.altKey) {
        return;
    }
    // Keys typed into form fields (e.g. the speed slider's number box) must
    // not fire label buttons.
    const target = e.target;
    if (target && (target.isContentEditable ||
            ["INPUT", "TEXTAREA", "SELECT"].includes(target.tagName))) {
        return;
    }
    let buttonId = null;
    if (e.key === "Enter") {
        buttonId = "btn_skip";
    } else if (/^[0-9]$/.test(e.key)) {
        buttonId = "btn_" + e.key;
    }
    if (buttonId === null) {
        return;
    }
    const button = document.getElementById(buttonId);
    if (!button) {
        return;
    }
    // Stop Enter from also natively activating whichever button has focus
    // (e.g. a label button that was just clicked with the mouse).
    e.preventDefault();
    button.click();
}
document.addEventListener('keydown', shortcuts, false);
</script>
"""
def modify_speed(
    data: tuple[int, np.ndarray], speed: float = 1.0
) -> tuple[int, np.ndarray]:
    """Return the ``(sampling_rate, samples)`` pair re-timed by ``speed``.

    At ``speed == 1.0`` the input tuple itself is returned unchanged. Any other
    value is passed as ``rate`` to ``librosa.effects.time_stretch`` (rate > 1
    speeds playback up, rate < 1 slows it down). The sampling rate is always
    carried over as-is.
    """
    if speed != 1.0:
        sample_rate, samples = data
        stretched = librosa.effects.time_stretch(samples, rate=speed)
        return sample_rate, stretched
    return data
def parse_item(item, speed: float = 1.0) -> dict:
    """Flatten one raw dataset sample into the fields the UI displays.

    Args:
        item: A sample from the dataset stream. The keys read are ``__key__``,
            ``txt``, ``cls`` and ``audio`` (itself holding ``sampling_rate``
            and ``array``).
        speed: Accepted for call-site compatibility but not used here.

    Returns:
        Dict with ``key``, ``audio`` as ``(sampling_rate, array)``, ``text``,
        the display ``label`` from ``id2rich_label`` and the raw ``label_id``.
    """
    audio_field = item["audio"]
    class_id = item["cls"]
    return dict(
        key=item["__key__"],
        audio=(audio_field["sampling_rate"], audio_field["array"]),
        text=item["txt"],
        label=id2rich_label[class_id],
        label_id=class_id,
    )
def get_next_parsed_item(speed: float = 1.0) -> dict:
    """Pull the next sample from the shared dataset stream and parse it.

    Args:
        speed: Forwarded to ``parse_item``.

    Returns:
        The dict produced by ``parse_item`` for the next sample.

    Raises:
        RuntimeError: If ``ds_iter`` is exhausted (every sample was served).
    """
    logger.info("Getting next item")
    try:
        next_item = next(ds_iter)
    except StopIteration:
        # A StopIteration escaping an ordinary function can be mistaken for
        # "end of iteration" by whatever called it (or turned into an opaque
        # RuntimeError inside a generator), so report exhaustion explicitly.
        raise RuntimeError("No more items left in the dataset stream") from None
    parsed = parse_item(next_item, speed=speed)
    logger.info(
        f"Next item:\nkey={parsed['key']}\ntext={parsed['text']}\nlabel={parsed['label']}"
    )
    return parsed
# Instructions rendered as Markdown at the top of the page (Japanese). The UI
# starts empty, so the first bullet group tells the user to load an item with
# the 「初期化・再読み込み」 button before labeling.
md = """
# 説明
- このアプリは、ゲームのセリフを感情ラベル付けして、大規模な感情音声データセットを作成するためのものです
- **性的な音声が含まれるため、18歳未満の方はご利用をお控えください**
- 最初に「初期化・再読み込み」ボタンを押して、音声を読み込んでください
- 既存のラベルが適切であれば、そのまま「現在の感情ラベルで適切」ボタンを押してください
- ラベルを修正する場合は、適切なボタンを押してください
- ショートカットキー(カッコ内)を使うこともできます

# 補足
- `🥰 NSFW1` は女性の性的行為中の音声(喘ぎ声等)を表します
- `🍭 NSFW2` はキスシーンでのリップ音やフェラシーンでのしゃぶる音(チュパ音)を表します
- 感情が音声からは特に読み取れない場合は `😐 中立` を選択してください
"""
with gr.Blocks(head=shortcut_js) as app:
    gr.Markdown(md)
    with gr.Row():
        with gr.Column():
            btn_init = gr.Button("初期化・再読み込み")
            speed = gr.Slider(
                minimum=0.5, maximum=5.0, step=0.1, value=1.0, label="再生速度"
            )
        with gr.Column(variant="panel"):
            # The key is sent verbatim with every submitted label, so it must
            # not be hand-editable; non-interactive components are still
            # passed to handlers as inputs.
            key = gr.Textbox(label="Key", interactive=False)
            audio = gr.Audio()
            text = gr.Textbox(label="Text")
            label = gr.Textbox(label="感情ラベル")
            label_id = gr.Textbox(visible=False)
            btn_skip = gr.Button("現在の感情ラベルで適切 (Enter)", elem_id="btn_skip")
        with gr.Column():
            gr.Markdown("# 感情ラベルを修正する場合")
            btn_list = [
                gr.Button(id2rich_label[_id], elem_id=f"btn_{_id}") for _id in range(10)
            ]

    # Every handler returns a dict keyed by exactly these components.
    item_outputs = [key, audio, text, label, label_id]
    # All listeners below advance the single module-level ``ds_iter`` and
    # rewrite the global ``current_item``. Gradio limits concurrency per
    # listener by default, so different buttons could run at the same time;
    # sharing one concurrency group (default limit 1) serializes them.
    item_concurrency_id = "item_stream"

    def update_current_item(data: dict) -> dict:
        """Show ``current_item``, fetching the first item if none is loaded.

        Args:
            data: Gradio input dict; only ``data[speed]`` is read.

        Returns:
            Output dict for ``item_outputs``; the audio is re-timed to the
            selected speed and set to autoplay.
        """
        global current_item
        if current_item is None:
            speed_value = data[speed]
            current_item = get_next_parsed_item(speed=speed_value)
        modified_audio = modify_speed(current_item["audio"], speed=data[speed])
        return {
            key: current_item["key"],
            audio: gr.Audio(modified_audio, autoplay=True),
            text: current_item["text"],
            label: current_item["label"],
            label_id: current_item["label_id"],
        }

    def set_next_item(data: dict) -> dict:
        """Advance ``current_item`` to the next dataset sample and show it."""
        global current_item
        speed_value = data[speed]
        current_item = get_next_parsed_item(speed=speed_value)
        return update_current_item(data)

    def _submit_label(item_key: str, cls: int) -> None:
        """Send class ``cls`` for ``item_key`` to the remote ``/put_data`` API."""
        _ = client.predict(
            new_data=json.dumps(
                {
                    "key": item_key,
                    "cls": cls,
                }
            ),
            api_name="/put_data",
        )

    def put_unmodified(data: dict) -> dict:
        """Confirm the displayed label as correct, then show the next item.

        If nothing has been loaded yet (empty Key box, e.g. Enter pressed
        before initialization), nothing is submitted and the current item is
        loaded instead, so an empty key never reaches the remote endpoint.
        """
        logger.info("Putting unmodified")
        current_key = data[key]
        if not current_key:
            logger.warning("No item loaded yet; nothing was submitted")
            return update_current_item(data)
        _submit_label(current_key, int(data[label_id]))
        logger.info("Unmodified sent")
        return set_next_item(data)

    btn_init.click(
        update_current_item,
        inputs={speed},
        outputs=item_outputs,
        concurrency_id=item_concurrency_id,
    )
    btn_skip.click(
        put_unmodified,
        inputs={key, label_id, speed},
        outputs=item_outputs,
        concurrency_id=item_concurrency_id,
    )

    functions_list = []
    for _id, btn in enumerate(btn_list):
        # ``_id=_id`` binds the current id now; a plain closure would see
        # only the loop's final value.
        def put_label(data: dict, _id=_id) -> dict:
            """Submit class ``_id`` for the displayed item, then show the next.

            Like ``put_unmodified``, submits nothing while no item is loaded.
            """
            logger.info(f"Putting label: {id2rich_label[_id]}")
            current_key = data[key]
            if not current_key:
                logger.warning("No item loaded yet; nothing was submitted")
                return update_current_item(data)
            _submit_label(current_key, _id)
            logger.info("Modified sent")
            return set_next_item(data)

        functions_list.append(put_label)
        btn.click(
            put_label,
            inputs={key, speed},
            outputs=item_outputs,
            concurrency_id=item_concurrency_id,
        )

app.launch()