# Space by juancopi81, duplicated from Whispering-GPT/whisper-youtube-2-hf_dataset
# (revision 7288748).
import math
import os
import argparse
import sqlite3
import shutil
import uuid
from datasets import Dataset, concatenate_datasets
import gradio as gr
import torch
from storing.createdb import create_db
from preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from loading.serialization import JsonSerializer
from utils import nest_list, is_google_colab
from datapipeline import create_hardcoded_data_pipeline
from threadeddatapipeline import ThreadedDataPipeline
from dataset.hf_dataset import HFDataset
from huggingface_hub import DatasetCard
# How many ThreadedDataPipeline workers `transcribe` splits the video paths across.
NUM_THREADS = 1

# Detect if code is running in Colab
is_colab = is_google_colab()

# HTML snippet linking to the Colab notebook; left empty when already in Colab.
if is_colab:
    colab_instruction = ""
else:
    colab_instruction = """
<p>You can skip the queue using Colab:
<a href="https://colab.research.google.com/drive/1zNRnX1lXjlGtBMW8U8S9t4eY1cA0D6lm?usp=sharing">
<img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>"""

# Label shown in the UI for the hardware torch reports as available.
if torch.cuda.is_available():
    device_print = "GPU 🔥"
else:
    device_print = "CPU 🥶"
def numvideos_type(x):
    """Argparse ``type=`` converter for ``--num_videos``.

    Converts *x* to ``int`` and checks that it lies in the inclusive
    range [1, 12].

    Raises:
        ValueError: if *x* cannot be converted with ``int()`` (argparse
            reports this as an invalid value).
        argparse.ArgumentTypeError: if the value is outside [1, 12].
    """
    min_videos, max_videos = 1, 12
    x = int(x)
    if x > max_videos:
        raise argparse.ArgumentTypeError("Maximum number of videos is 12")
    if x < min_videos:
        # The message previously claimed the minimum was 12.
        raise argparse.ArgumentTypeError("Minimum number of videos is 1")
    return x
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    All five options are mandatory. ``--num_videos`` is validated by
    ``numvideos_type`` and ``--whisper_model`` is restricted to the listed
    model sizes.

    Returns:
        argparse.Namespace with ``channel_name``, ``num_videos``,
        ``hf_token``, ``hf_dataset_identifier`` and ``whisper_model``.
    """
    cli = argparse.ArgumentParser(
        usage="[arguments] --channel_name --num_videos",
        description="Program to transcribe YouTube videos.",
    )
    # (flag, keyword arguments) pairs, registered in this order so the
    # generated help text lists them as before.
    option_table = (
        ("--channel_name", {
            "type": str,
            "required": True,
            "help": "Name of the channel from where the videos will be transcribed",
        }),
        ("--num_videos", {
            "type": numvideos_type,
            "required": True,
            "help": "Number of videos (min. 1 - max. 12) to transcribe from --channel_name",
        }),
        ("--hf_token", {
            "type": str,
            "required": True,
            "help": "Token of your HF account. You need a HF account to upload the dataset",
        }),
        ("--hf_dataset_identifier", {
            "type": str,
            "required": True,
            "help": "The ID of the repository to push to in the following format: <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.",
        }),
        ("--whisper_model", {
            "type": str,
            "required": True,
            "help": "Select one of the available whispers models",
            "choices": ["tiny", "base", "small", "medium", "large"],
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
# Columns copied from the local SQLite VIDEO table into the HF dataset.
_NEW_VIDEOS_QUERY = (
    "SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO"
)


def _run_pipelines(paths, database_name, whisper_model):
    """Split *paths* into NUM_THREADS chunks and run one threaded pipeline per chunk."""
    chunk_size = math.ceil(len(paths) / NUM_THREADS)
    path_chunks = nest_list(paths, chunk_size)
    pipelines = [create_hardcoded_data_pipeline(database_name, whisper_model)
                 for _ in range(NUM_THREADS)]
    threads = [ThreadedDataPipeline(pipeline, chunk)
               for pipeline, chunk in zip(pipelines, path_chunks)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def _load_new_rows(database_name):
    """Read the transcribed videos from the SQLite file into a ``Dataset``."""
    connection = sqlite3.connect(database_name)
    try:
        return Dataset.from_sql(_NEW_VIDEOS_QUERY, connection)
    finally:
        connection.close()


def _push_with_card(new_rows, hf_dataset, hf_dataset_identifier, hf_token, whisper_model):
    """Append *new_rows* to the existing hub dataset (if any), push it, and tag its card."""
    if hf_dataset.exist and not hf_dataset.is_empty:
        dataset_to_upload = concatenate_datasets([hf_dataset.dataset["train"], new_rows])
    else:
        dataset_to_upload = new_rows
    dataset_to_upload.push_to_hub(hf_dataset_identifier, token=hf_token)
    # Pass the token so loading the card also works for private repositories.
    card = DatasetCard.load(hf_dataset_identifier, token=hf_token)
    card.data.tags = ["whisper", "whispering", whisper_model]
    card.data.task_categories = ["automatic-speech-recognition"]
    card.push_to_hub(hf_dataset_identifier, token=hf_token)


def _cleanup(database_name, dataset_folder):
    """Best-effort removal of the temporary database and the downloaded media folder."""
    try:
        os.remove(database_name)
    except OSError as e:
        print(f"Error: {database_name} : {e.strerror}")
    if dataset_folder is not None:
        try:
            shutil.rmtree(dataset_folder)
        except OSError as e:
            print(f"Error: {dataset_folder} : {e.strerror}")


def transcribe(mode: str,
               channel_name: str,
               num_videos: int,
               hf_token: str,
               hf_dataset_identifier: str,
               whisper_model: str) -> str:
    """Transcribe YouTube videos with Whisper and push them to a HF dataset.

    Args:
        mode: Video source selector passed to ``YoutubeVideoPreprocessor``
            (the UI offers ``"channel_name"`` or ``"playlist"``).
        channel_name: YouTube channel name or playlist URL.
        num_videos: Number of videos requested from the preprocessor.
        hf_token: Hugging Face token used to push the dataset and its card.
        hf_dataset_identifier: Hub repo id of the dataset to create/extend.
        whisper_model: Whisper model size handed to each data pipeline.

    Returns:
        A human-readable status message.

    Side effects:
        Creates a uniquely named ``.db`` file in the working directory and
        removes it, together with the folder returned by the preprocessor,
        whether or not the run succeeds.
    """
    # A unique file name so separate runs never share a database.
    database_name = str(uuid.uuid4()) + ".db"
    create_db(database_name)
    dataset_folder = None
    try:
        yt_video_processor = YoutubeVideoPreprocessor(mode=mode,
                                                      serializer=JsonSerializer())  # TODO: Let user select serializer
        hf_dataset = HFDataset(hf_dataset_identifier)
        # Ids already in the hub dataset, so they are not transcribed again.
        videos_downloaded = hf_dataset.list_of_ids
        # Defensive int(): the value may arrive from the UI slider as a float.
        paths, dataset_folder = yt_video_processor.preprocess(channel_name,
                                                              int(num_videos),
                                                              videos_downloaded)
        if not paths:
            # Nothing new: avoid chunking an empty list with size 0 and
            # re-uploading an unchanged dataset.
            return (f"No new videos to transcribe for {channel_name}. "
                    f"Dataset {hf_dataset_identifier} was not modified.")
        _run_pipelines(paths, database_name, whisper_model)
        new_rows = _load_new_rows(database_name)
        num_new_videos = new_rows.num_rows
        _push_with_card(new_rows, hf_dataset, hf_dataset_identifier, hf_token, whisper_model)
    finally:
        # Runs on success, early return and exceptions alike.
        _cleanup(database_name, dataset_folder)
    return f"Dataset created or updated at {hf_dataset_identifier}. {num_new_videos} samples were added"
# Gradio UI: collects the transcription settings and wires the Submit
# button to `transcribe`, showing its returned message in `output`.
with gr.Blocks() as demo:
    md = """# Use Whisper to create a HF dataset from YouTube videos
This space will let you create a HF dataset by transcribing videos from YouTube.
Enter the name of the YouTube channel or the URL of a YouTube playlist (in the form https://www.youtube.com/playlist?list=****),
and the repo_id of the dataset (you need a HuggingFace account).
If the dataset already exists, it will only transcribe videos that are not in the dataset.
If it does not exist, it will create the dataset. To use this demo, you need a
[Hugging Face token](https://huggingface.co/settings/tokens) with write role. Learn more about [tokens](https://huggingface.co/docs/hub/security-tokens).
"""
    gr.Markdown(md)
    gr.HTML(
        f"""
<p style="margin-bottom: 10px; font-size: 94%">
Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
</p>
"""
    )
    with gr.Row():
        with gr.Column():
            whisper_model = gr.Radio([
                "tiny", "base", "small", "medium", "large"
            ], label="Whisper model", value="base")
            # Value is forwarded to transcribe() as its `mode` argument.
            mode = gr.Radio([
                "channel_name", "playlist"
            ], label="Get the videos from:", value="channel_name")
            channel_name = gr.Textbox(label="YouTube Channel or Playlist URL",
                                      placeholder="Enter the name of the YouTube channel or the URL of the playlist")
            num_videos = gr.Slider(1, 20000, value=4, step=1, label="Number of videos")
            # Previously unlabeled; a label makes the password field identifiable.
            hf_token = gr.Textbox(label="HF token",
                                  placeholder="Your HF write access token",
                                  type="password")
            hf_dataset_identifier = gr.Textbox(label='Dataset Name',
                                               placeholder="Enter in the format <username>/<repo_name>")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Text()
    # Input order must match transcribe()'s positional parameters.
    submit_btn.click(fn=transcribe, inputs=[mode,
                                            channel_name,
                                            num_videos,
                                            hf_token,
                                            hf_dataset_identifier,
                                            whisper_model], outputs=[output])
    gr.Markdown('''
![visitors](https://visitor-badge.glitch.me/badge?page_id=juancopi81.whisper-youtube-2-hf_dataset)
''')
# Outside Colab, queue requests with a concurrency of 1 so only one job is
# processed at a time; share=is_colab asks Gradio for a public link in Colab.
launch_options = {"debug": True, "share": is_colab}
if not is_colab:
    demo.queue(concurrency_count=1)
demo.launch(**launch_options)