import math
import os
import argparse
import sqlite3
import shutil
import uuid

from datasets import Dataset, concatenate_datasets
import gradio as gr
import torch
from huggingface_hub import DatasetCard

from storing.createdb import create_db
from preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from loading.serialization import JsonSerializer
from utils import nest_list, is_google_colab
from datapipeline import create_hardcoded_data_pipeline
from threadeddatapipeline import ThreadedDataPipeline
from dataset.hf_dataset import HFDataset

NUM_THREADS = 1
# Detect if code is running in Colab
is_colab = is_google_colab()

colab_instruction = "" if is_colab else """
<p>You can skip the queue using Colab:
<a href="https://colab.research.google.com/drive/1zNRnX1lXjlGtBMW8U8S9t4eY1cA0D6lm?usp=sharing">
<img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>"""

device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
def numvideos_type(x):
    x = int(x)
    if x > 12:
        raise argparse.ArgumentTypeError("Maximum number of videos is 12")
    if x < 1:
        raise argparse.ArgumentTypeError("Minimum number of videos is 1")
    return x
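# Example behavior (illustrative):
#   numvideos_type("5")  -> 5
#   numvideos_type("13") -> raises argparse.ArgumentTypeError
#   numvideos_type("0")  -> raises argparse.ArgumentTypeError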
def parse_args():
    parser = argparse.ArgumentParser(usage="[arguments] --channel_name --num_videos",
                                     description="Program to transcribe YouTube videos.")
    parser.add_argument("--channel_name",
                        type=str,
                        required=True,
                        help="Name of the channel whose videos will be transcribed")
    parser.add_argument("--num_videos",
                        type=numvideos_type,
                        required=True,
                        help="Number of videos (min. 1, max. 12) to transcribe from --channel_name")
    parser.add_argument("--hf_token",
                        type=str,
                        required=True,
                        help="Token of your HF account. You need an HF account to upload the dataset")
    parser.add_argument("--hf_dataset_identifier",
                        type=str,
                        required=True,
                        help="The ID of the repository to push to, in the format <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which defaults to the namespace of the logged-in user.")
    parser.add_argument("--whisper_model",
                        type=str,
                        required=True,
                        help="Select one of the available Whisper models",
                        choices=["tiny", "base", "small", "medium", "large"])
    args = parser.parse_args()
    return args
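# Example CLI invocation (illustrative; assumes this file is saved as app.py):
#   python app.py --channel_name "SomeChannel" --num_videos 4 \
#       --hf_token hf_xxx --hf_dataset_identifier user/yt-transcripts \
#       --whisper_model base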
def transcribe(mode: str,
               channel_name: str,
               num_videos: int,
               hf_token: str,
               hf_dataset_identifier: str,
               whisper_model: str) -> str:
    # Create a unique name for the database
    unique_filename = str(uuid.uuid4())
    database_name = unique_filename + ".db"
    create_db(database_name)
    # Create the necessary resources
    yt_video_processor = YoutubeVideoPreprocessor(mode=mode,
                                                  serializer=JsonSerializer())  # TODO: Let user select serializer
    hf_dataset = HFDataset(hf_dataset_identifier)
    videos_downloaded = hf_dataset.list_of_ids
    paths, dataset_folder = yt_video_processor.preprocess(channel_name,
                                                          num_videos,
                                                          videos_downloaded)
    nested_list_length = math.ceil(len(paths) / NUM_THREADS)
    nested_paths = nest_list(paths, nested_list_length)
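    # Illustrative sketch of the split (assuming nest_list chunks a flat list into
    # sublists of the given length): nest_list([p1, p2, p3, p4], 2) -> [[p1, p2], [p3, p4]],
    # so each thread below receives one sublist of video paths.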
    data_pipelines = [create_hardcoded_data_pipeline(database_name, whisper_model) for _ in range(NUM_THREADS)]
    # Run the pipelines in multiple threads
    threads = []
    for data_pipeline, thread_paths in zip(data_pipelines, nested_paths):
        threads.append(ThreadedDataPipeline(data_pipeline, thread_paths))
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
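    # With NUM_THREADS = 1 this is effectively sequential; raising NUM_THREADS would
    # fan the downloaded paths out across several ThreadedDataPipeline workers.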
    # Fetch the new entries to count how many videos were transcribed
    connection = sqlite3.connect(database_name)
    cursor = connection.cursor()
    cursor.execute("SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO")
    videos = cursor.fetchall()
    num_new_videos = len(videos)
    dataset = Dataset.from_sql("SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO", connection)
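    # Dataset.from_sql accepts a SQL query plus an open sqlite3 connection, so the
    # HF dataset is built directly from the temporary database populated above.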
    if hf_dataset.exist and not hf_dataset.is_empty:
        dataset_to_upload = concatenate_datasets([hf_dataset.dataset["train"], dataset])
    else:
        dataset_to_upload = dataset
    dataset_to_upload.push_to_hub(hf_dataset_identifier, token=hf_token)
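    # The dataset card (the YAML front matter of the repo's README) is updated below
    # with tags and a task category so the dataset is discoverable on the Hub.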
    card = DatasetCard.load(hf_dataset_identifier)
    card.data.tags = ["whisper", "whispering", whisper_model]
    card.data.task_categories = ["automatic-speech-recognition"]
    card.push_to_hub(hf_dataset_identifier, token=hf_token)
    # Close the connection and remove the temporary database and downloads
    connection.close()
    os.remove(database_name)
    try:
        shutil.rmtree(dataset_folder)
    except OSError as e:
        print(f"Error: {dataset_folder} : {e.strerror}")
    return f"Dataset created or updated at {hf_dataset_identifier}. {num_new_videos} samples were added"
with gr.Blocks() as demo:
    md = """# Use Whisper to create an HF dataset from YouTube videos

    This space lets you create an HF dataset by transcribing videos from YouTube.
    Enter the name of the YouTube channel or the URL of a YouTube playlist (in the form https://www.youtube.com/playlist?list=****),
    and the repo_id of the dataset (you need a Hugging Face account).
    If the dataset already exists, only videos that are not yet in it will be transcribed.
    If it does not exist, the dataset will be created. To use this demo, you need a
    [Hugging Face token](https://huggingface.co/settings/tokens) with write role. Learn more about [tokens](https://huggingface.co/docs/hub/security-tokens).
    """
    gr.Markdown(md)

    gr.HTML(
        f"""
        <p style="margin-bottom: 10px; font-size: 94%">
            Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            whisper_model = gr.Radio(["tiny", "base", "small", "medium", "large"],
                                     label="Whisper model", value="base")
            mode = gr.Radio(["channel_name", "playlist"],
                            label="Get the videos from:", value="channel_name")
            channel_name = gr.Textbox(label="YouTube Channel or Playlist URL",
                                      placeholder="Enter the name of the YouTube channel or the URL of the playlist")
            num_videos = gr.Slider(1, 20000, value=4, step=1, label="Number of videos")
            hf_token = gr.Textbox(label="HF write access token",
                                  placeholder="Your HF write access token", type="password")
            hf_dataset_identifier = gr.Textbox(label="Dataset Name",
                                               placeholder="Enter in the format <username>/<repo_name>")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Text()

    submit_btn.click(fn=transcribe,
                     inputs=[mode,
                             channel_name,
                             num_videos,
                             hf_token,
                             hf_dataset_identifier,
                             whisper_model],
                     outputs=[output])
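    # The order of the inputs list above must match the parameter order of transcribe().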
    gr.Markdown('''
    ![visitors](https://visitor-badge.glitch.me/badge?page_id=juancopi81.whisper-youtube-2-hf_dataset)
    ''')
if not is_colab:
    demo.queue(concurrency_count=1)

demo.launch(debug=True, share=is_colab)