Spaces:
Runtime error
Runtime error
import os | |
import re | |
import time | |
import shutil | |
from pathlib import Path | |
from functools import partial | |
from typing import Union, Dict, List | |
import torch | |
from torch.utils.data import DataLoader | |
import datasets | |
from datasets import load_dataset, Dataset | |
from transformers import AutoTokenizer, PreTrainedTokenizer | |
from huggingface_hub import Repository, create_repo, HfApi | |
from optimum.onnxruntime import ( | |
AutoOptimizationConfig, | |
ORTModelForFeatureExtraction, | |
ORTOptimizer, | |
) | |
# Turn off the HF tokenizers library's internal parallelism; tokenization also
# happens inside DataLoader worker processes (num_workers=2 in batch_embed).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Optimum ONNX Runtime optimization presets, keyed by the level name the app
# passes around ("O2", "O3", "O4").
opt_configs = {
    level: getattr(AutoOptimizationConfig, level)()
    for level in ("O2", "O3", "O4")
}
def get_batch_size(device_name: str, model_name: str, opt_level: str):
    """
    Choose an inference batch size from rough model-size / GPU / opt-level heuristics.

    TODO: run actual tests
    T4 has 16GB
    A10 has 24GB

    Args:
        device_name (`str`):
            The name of the GPU device in use.
        model_name (`str`):
            The name of the model in use. Matched by substring against
            "small", "base" and "large" (first match wins).
        opt_level (`str`):
            The optimization level in use. "O4" doubles the batch size.
    Returns:
        `int`:
            The batch size to use.
    """
    # Ordered: the first tag found in model_name decides the base size.
    size_by_tag = (("small", 192), ("base", 128), ("large", 64))
    base = next((size for tag, size in size_by_tag if tag in model_name), 32)

    factor = 1
    # Substring match, so e.g. "A10G" and "A100" also qualify.
    if "A10" in device_name:
        factor *= 2
    if opt_level == "O4":
        factor *= 2
    return base * factor
def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """
    Average token embeddings over the sequence axis, ignoring masked-out tokens.

    Args:
        last_hidden_state (`torch.Tensor`):
            Token embeddings (the model's first output). The mask broadcast and
            the sum over dim 1 expect shape (batch, seq_len, hidden).
        attention_mask (`torch.Tensor`):
            Mask of shape (batch, seq_len); non-zero entries mark tokens to keep.
    Returns:
        `torch.Tensor`:
            The (batch, hidden) mean-pooled embeddings.
    """
    weights = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
    token_sum = (last_hidden_state * weights).sum(dim=1)
    # Clamp so a row with no kept tokens divides by a tiny value instead of zero.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count
def load_hf_dataset(ds_name: str, ds_config: str = None, ds_split: str = "train"):
    """
    Load a dataset from the HuggingFace Hub, or from local parquet chunks for Wikipedia.

    For ``ds_name == "wikipedia"`` the data is read from the ``chunk_*`` parquet
    files that ``download_wikipedia`` writes under ``/data``. In that case
    ``ds_split`` is ignored and the ``"train"`` split of those files is returned.
    Any other name is loaded (fully, not streamed) with ``datasets.load_dataset``.

    Args:
        ds_name (`str`):
            The name of the dataset to load.
        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset to load. An empty string is treated
            as `None`.
        ds_split (`str`, *optional*, Defaults to `"train"`):
            The split of the dataset to load (not used for Wikipedia).
    Returns:
        ds (`datasets.Dataset`):
            The loaded dataset.
    """
    if ds_config == "":
        ds_config = None

    if ds_name == "wikipedia":
        # Folder naming must match download_wikipedia. ds_config may be None here,
        # so fall back to "" instead of failing on `str + None`.
        folder = Path("/data") / re.sub(r"[^a-zA-Z0-9]", "", ds_name + (ds_config or ""))

        def chunk_index(path):
            # chunk_<n> -> n. Sorting by name alone would put chunk_10 before chunk_2.
            digits = re.search(r"\d+", path.name)
            return (int(digits.group()) if digits else -1, path.name)

        # glob() order is arbitrary; sort so rows come back in their original order.
        files = [str(path) for path in sorted(folder.glob("chunk_*"), key=chunk_index)]
        return load_dataset("parquet", data_files=files, split="train")

    return load_dataset(ds_name, ds_config, split=ds_split)
def download_wikipedia(ds_name, ds_config, num2skip, num2embed):
    """
    Stream a slice of a Hub dataset and cache its "text" column locally as parquet chunks.

    The selected examples are materialized into a ``datasets.Dataset`` and written
    to ``/data/<ds_name + ds_config with non-alphanumerics removed>`` as files named
    ``chunk_<n>``, each holding up to 20,000 rows. Chunk files left over from an
    earlier run in that folder are removed first.

    Args:
        ds_name (`str`):
            Name of the dataset on the Hub.
        ds_config (`str` or `None`):
            Dataset configuration; `None` or "" for the default.
        num2skip (`int`):
            Number of leading examples to skip.
        num2embed (`int`):
            Number of examples to keep after skipping; values <= 0 keep all
            remaining examples.
    Returns:
        `datasets.Dataset`:
            The "train" split loaded from the newly written parquet chunks.
    """
    ds = load_dataset(ds_name, ds_config, streaming=True, split="train")

    def gen():
        stream = ds.skip(num2skip)
        if num2embed > 0:
            stream = stream.take(num2embed)
        for example in stream:
            yield {"text": example["text"]}

    ds2 = Dataset.from_generator(gen)

    chunk_size = 20_000
    # ds_config may be None; avoid `str + None`.
    folder = Path("/data") / re.sub(r"[^a-zA-Z0-9]", "", ds_name + (ds_config or ""))
    folder.mkdir(exist_ok=True, parents=True)
    # load_hf_dataset globs every chunk_* file in this folder, so chunks from a
    # previous (possibly larger) run would otherwise be mixed into the result.
    for stale in folder.glob("chunk_*"):
        stale.unlink()

    filenames = []
    for chunk_num, start_idx in enumerate(range(0, len(ds2), chunk_size)):
        end_idx = min(start_idx + chunk_size, len(ds2))
        chunk_path = str(folder / f"chunk_{chunk_num}")
        ds2.select(range(start_idx, end_idx)).to_parquet(chunk_path)
        filenames.append(chunk_path)
    return load_dataset("parquet", data_files=filenames, split="train")
def get_model_and_tokenizer(model_name: str, optimization_level: str, progress):
    """
    Return an ONNX Runtime–optimized feature-extraction model and its tokenizer.

    Artifacts are cached in a local folder named after ``model_name`` (with "/"
    replaced by "_"). If ``model_optimized_<level>.onnx`` is not there yet, the
    tokenizer and model are downloaded, the model is exported to ONNX, optimized
    with the matching preset from ``opt_configs``, and saved. The optimized model
    is then loaded with the CUDA execution provider.

    Args:
        model_name (`str`):
            Hub id of the model to load.
        optimization_level (`str`):
            The optimization level to use. Should be one of `"O2"`, `"O3"`, or `"O4"`.
        progress (callable or `None`):
            Optional progress callback, called as ``progress(fraction, message)``.
    Returns:
        model (`ORTModelForFeatureExtraction`):
            The optimized model.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer.
    """

    def report(fraction, message):
        if progress is not None:
            progress(fraction, message)

    optimized_model_name = f"model_optimized_{optimization_level}.onnx"
    model_dir = Path(model_name.replace("/", "_"))
    optimized_path = model_dir / optimized_model_name

    if optimized_path.exists():
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
    else:
        report(0.2, "Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.save_pretrained(model_dir)

        report(0.4, "Downloading model...")
        exported = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
        exported.save_pretrained(model_dir)

        optimizer = ORTOptimizer.from_pretrained(exported)
        preset = opt_configs[optimization_level]
        report(0.6, "Optimizing model...")
        optimizer.optimize(save_dir=model_dir, optimization_config=preset)
        # optimize() leaves its output as model_optimized.onnx; rename it so each
        # optimization level gets its own cached file.
        (model_dir / "model_optimized.onnx").rename(optimized_path)

    report(0.8, "Loading optimized model and tokenizer...")
    model = ORTModelForFeatureExtraction.from_pretrained(
        model_dir,
        file_name=optimized_model_name,
        provider="CUDAExecutionProvider",
    )
    return model, tokenizer
def tokenize(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer,
    column_name: str = "text",
    padding: Union[bool, str] = True,
    max_length: int = 512,
):
    """
    Tokenize one column of a batch of examples (for use with `Dataset.map(batched=True)`).

    Args:
        examples (`Dict[str, List[str]]`):
            Batch of examples, mapping column name to a list of values.
        tokenizer (`PreTrainedTokenizer`):
            Tokenizer to apply.
        column_name (`str`, *optional*, defaults to `text`):
            Column whose texts are tokenized.
        padding (`bool` or `str`, *optional*, defaults to `True`):
            `True` pads to the longest sequence in the batch; use
            `"max_length"` (e.g. for the `O4` optimization level) to pad every
            sequence to `max_length`.
        max_length (`int`, *optional*, defaults to `512`):
            Longer sequences are truncated to this length; it is also the
            padding target when `padding="max_length"`.
    Returns:
        `Dict[str, List[List[int]]]`:
            The tokenizer output for the batch.
    """
    # TODO: add lengths, sort by length, use dynamic padding
    # TODO: option for controlling length for models that can go shorter/longer than 512
    texts = examples[column_name]
    options = dict(truncation=True, padding=padding, max_length=max_length)
    return tokenizer(texts, **options)
def collate_fn(examples, tokenizer=None, padding=None, column_name="text", max_length=512):
    """
    Tokenize a list of raw examples into one model-ready batch (DataLoader `collate_fn`).

    Args:
        examples (`List[dict]`):
            Rows yielded by the dataset; each must contain `column_name`.
        tokenizer (callable):
            Tokenizer to apply. Required despite the `None` default, which exists
            only so the argument can be bound with `functools.partial`.
        padding (`bool` or `str`, *optional*):
            Forwarded to the tokenizer. `"max_length"` pads every row to
            `max_length`; `True` pads to the longest row in the batch.
        column_name (`str`, *optional*, defaults to `"text"`):
            Column holding the text to tokenize.
        max_length (`int`, *optional*, defaults to `512`):
            Truncation length (and padding target for `padding="max_length"`).
    Returns:
        The tokenizer output (PyTorch tensors, since `return_tensors="pt"`), with
        the raw texts added under `column_name` so they can be stored next to
        their embeddings.
    Raises:
        ValueError: if `examples` is empty or `tokenizer` is None.
        KeyError: if an example has no `column_name` field.
    """
    if not examples:
        raise ValueError("collate_fn received an empty batch")
    if tokenizer is None:
        raise ValueError("collate_fn requires a tokenizer")

    texts = [example[column_name] for example in examples]
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding=padding,
        max_length=max_length,
        return_tensors="pt",
    )
    tokenized[column_name] = texts
    return tokenized
def batch_embed(
    ds: datasets.IterableDataset,
    model: ORTModelForFeatureExtraction,
    tokenizer: PreTrainedTokenizer,
    model_name: str,
    column_name: str,
    new_dataset_id: str,
    opt_level: str,
    upload_batch_size: int = 10_000,
    map_batch_size: int = 2000,
    num2skip: int = 0,
    num2embed: int = -1,
    progress=None,
):
    """
    Run the model on the dataset and upload the embeddings to the hub.

    Args:
        ds (`datasets.Dataset`):
            dataset to embed. From `load_hf_dataset`
        model (`ORTModelForFeatureExtraction`):
            model to use for embedding. From `get_model_and_tokenizer`
        tokenizer (`AutoTokenizer`):
            tokenizer to use for embedding. From `get_model_and_tokenizer`
        model_name (`str`):
            name of the model to use. Used to determine batch size.
        column_name (`str`):
            column name to use for embedding. Default option in gradio app is `text`
        new_dataset_id (`str`):
            id of the new dataset to create. Should include username or organization.
            e.g. nbroad/new-embeddings. A bare name is prefixed with the token's username.
        opt_level (`str`):
            optimization level to use. Should be one of `O2`, `O3`, `O4`
            See here for more details on optimization levels:
            https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration
        upload_batch_size (`int`, *optional*, defaults to `10_000`):
            embeddings are pushed once more than this many are pending.
        map_batch_size (`int`, *optional*, defaults to `2000`):
            currently unused (tokenization happens in `collate_fn`); kept for
            API compatibility.
        num2skip (`int`, *optional*, defaults to `0`):
            number of examples to skip (applied here via `ds.skip`).
            NOTE(review): if `ds` was already sliced upstream (e.g. by
            `download_wikipedia`), the caller should presumably pass 0 — confirm.
        num2embed (`int`, *optional*, defaults to `-1`):
            number of examples to embed. Values <= 0 mean all examples.
    Returns:
        current_count (`int`):
            number of examples embedded by this call (skipped ones excluded)
        time_taken (`float`):
            seconds spent in the embedding loop (the final upload is excluded)
    """
    api = HfApi(
        token=os.environ["HF_TOKEN"],
    )
    username = api.whoami()["name"]
    if "/" not in new_dataset_id:
        new_dataset_id = username + "/" + new_dataset_id
    # Creates the Hub dataset repo and a local clone; the returned handle is not needed.
    init_git_repo(new_dataset_id)

    # None means "no limit". Without this, the default -1 made the limit check
    # below fire on the very first batch and discard everything.
    limit = num2embed if num2embed > 0 else None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # get_device_name needs CUDA; fall back to a plain label on the CPU path.
    device_name = torch.cuda.get_device_name(0) if device.type == "cuda" else "cpu"
    inference_bs = get_batch_size(device_name, model_name, opt_level)

    # skip through some examples if specified
    if num2skip > 0:
        ds = ds.skip(num2skip)

    loader = DataLoader(
        ds,
        batch_size=inference_bs,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(
            collate_fn,
            column_name=column_name,
            tokenizer=tokenizer,
            # O4 is run with fixed-length ("max_length") padding.
            padding="max_length" if opt_level == "O4" else True,
        ),
    )

    embeds = []
    texts = []
    # last_count: how many had been embedded at the last push
    last_count = 0
    # current_count: how many have been embedded in total by this call
    current_count = 0

    start_time = time.time()
    for batch in loader:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        t_ids = torch.zeros_like(ids)
        outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
        embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
        texts.extend(batch[column_name])
        current_count += ids.shape[0]

        # Stop once enough examples are embedded, dropping the last batch's overshoot.
        # Slice to an explicit end index: `embeds[:-0]` would be empty.
        if limit is not None and current_count >= limit:
            keep = len(embeds) - (current_count - limit)
            embeds = embeds[:keep]
            texts = texts[:keep]
            current_count = limit
            break

        # Periodically upload to the hub
        if len(embeds) > upload_batch_size:
            push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)
            embeds = []
            texts = []
            last_count = current_count

        if progress is not None:
            progress(
                (current_count, None),
                "Embedding docs...",
                total=None,
                unit="Docs Embedded",
            )

    time_taken = time.time() - start_time

    # Upload whatever has not been pushed yet.
    if embeds:
        push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)

    return current_count, time_taken
def init_git_repo(repo_id: str):
    """
    Create the private dataset repo on the Hub (if needed) and clone it locally.

    The clone lives in a folder named after ``repo_id`` with "/" replaced by "_".
    If cloning into that folder raises ``EnvironmentError`` (``OSError``), the
    folder is deleted and the clone is retried once.
    ***Removes the existing local folder in that case***

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization.
            e.g. nbroad/new-embeddings
    Returns:
        `huggingface_hub.Repository`:
            The local clone, after a ``git pull``.
    """
    local_dir = repo_id.replace("/", "_")
    token = os.environ["HF_TOKEN"]
    create_repo(
        repo_id,
        repo_type="dataset",
        token=token,
        private=True,
        exist_ok=True,
    )

    def clone():
        return Repository(
            local_dir=local_dir,
            clone_from=repo_id,
            repo_type="dataset",
            token=token,
            skip_lfs_files=True,
        )

    try:
        repo = clone()
    except EnvironmentError:
        # Presumably the folder holds an incompatible checkout; start fresh.
        # ignore_errors so a missing folder doesn't raise FileNotFoundError
        # here and mask the original clone error.
        shutil.rmtree(local_dir, ignore_errors=True)
        repo = clone()

    repo.git_pull()
    return repo
def push_to_repo(
    repo_id: str,
    last_count: int,
    current_count: int,
    embeds: List[List[float]],
    texts: List[str],
    api: HfApi,
):
    """
    Write a batch of embeddings to a local parquet file and upload it in the background.

    The file is ``<repo_id with "/"→"_">/data/embeddings_<last_count:08d>_<current_count>.parquet``
    and is uploaded to ``data/`` in the dataset repo. The local file is deleted
    once its own upload has finished successfully. It is kept if the upload
    fails or is cancelled.

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization.
        last_count (`int`):
            Number of embeddings already pushed before this batch.
        current_count (`int`):
            Number of embeddings pushed once this batch is included.
        embeds (`List[List[float]]`):
            list of embeddings to push to the repo
        texts (`List[str]`):
            list of texts to push to the repo
        api (`huggingface_hub.HfApi`):
            api to use to push to the repo
    Returns:
        The Future returned by ``api.upload_file(..., run_as_future=True)``;
        callers may wait on it to make sure the upload has finished.
    """
    temp_ds = Dataset.from_dict(
        {
            "embedding": embeds,
            "text": texts,
        }
    )
    data_dir = Path(repo_id.replace("/", "_")) / "data"
    data_dir.mkdir(exist_ok=True, parents=True)
    # use zfill so sorting puts the files in order
    filename = f"embeddings_{str(last_count).zfill(8)}_{current_count}.parquet"
    filepath = data_dir / filename
    temp_ds.to_parquet(str(filepath))

    future = api.upload_file(
        path_or_fileobj=str(filepath),
        path_in_repo=f"data/{filename}",
        repo_id=repo_id,
        repo_type="dataset",
        run_as_future=True,
        token=os.environ["HF_TOKEN"],
        commit_message=f"Embedded examples {last_count} thru {current_count}",
    )

    def cleanup(done):
        # The upload runs in the background, so deleting files by age (as before)
        # could remove one that is still queued. Delete only after success.
        if done.cancelled() or done.exception() is not None:
            reason = "cancelled" if done.cancelled() else repr(done.exception())
            print(f"Upload of {filename} did not complete ({reason}); keeping {filepath}")
            return
        filepath.unlink(missing_ok=True)

    future.add_done_callback(cleanup)
    return future