import os
import time
import shutil
from collections.abc import Mapping
from pathlib import Path
from typing import Union, Dict, List, Optional

import torch
from torch.utils.data import DataLoader

import datasets
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer
from huggingface_hub import Repository, create_repo, HfApi
from optimum.onnxruntime import (
    AutoOptimizationConfig,
    ORTModelForFeatureExtraction,
    ORTOptimizer,
)

# Presumably set to avoid the `tokenizers` parallelism warning/deadlock once
# the DataLoader worker process is forked below — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Maps an optimization level name to the optimum ONNX Runtime config used by
# `get_model_and_tokenizer`.
opt_configs = {
    "O2": AutoOptimizationConfig.O2(),
    "O3": AutoOptimizationConfig.O3(),
    "O4": AutoOptimizationConfig.O4(),
}


def get_batch_size(device_name: str, model_name: str, opt_level: str) -> int:
    """
    Heuristically pick an inference batch size.

    TODO: run actual tests
    T4 has 16GB
    A10 has 24GB

    Args:
        device_name (`str`):
            The name of the GPU device in use.
        model_name (`str`):
            The name of the model in use.
        opt_level (`str`):
            The optimization level in use.

    Returns:
        `int`: The batch size to use.
    """
    if "small" in model_name:
        bs = 192
    elif "base" in model_name:
        bs = 128
    elif "large" in model_name:
        bs = 64
    else:
        bs = 32

    # Larger-memory GPU: double the batch.
    if "A10" in device_name:
        bs *= 2

    # O4 adds fp16, which roughly halves activation memory.
    if opt_level == "O4":
        bs *= 2

    return bs


def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """
    Mean pool the token embeddings, ignoring padded positions.

    Args:
        last_hidden_state (`torch.Tensor`):
            Token embeddings from the model, `(batch, seq_len, hidden)`.
        attention_mask (`torch.Tensor`):
            The attention mask, `(batch, seq_len)`; 0 marks padding.

    Returns:
        `torch.Tensor`: The mean pooled embeddings, `(batch, hidden)`.
    """
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    )
    # The clamp avoids division by zero for rows whose mask is all zeros.
    return torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def load_hf_dataset(
    ds_name: str, ds_config: Optional[str] = None, ds_split: str = "train"
):
    """
    Load a dataset from the HuggingFace Hub.

    The dataset is streamed, so the whole dataset is not downloaded to local
    storage.

    Args:
        ds_name (`str`):
            The name of the dataset to load.
        ds_config (`str`, *optional*, Defaults to `None`):
            The configuration of the dataset to load. An empty string is
            treated as `None`.
        ds_split (`str`, *optional*, Defaults to `"train"`):
            The split of the dataset to load.

    Returns:
        ds (`datasets.IterableDataset`):
            The loaded dataset.
    """
    if ds_config == "":
        ds_config = None

    ds = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)

    return ds


def get_model_and_tokenizer(model_name: str, optimization_level: str, progress=None):
    """
    Load the model and tokenizer from the HuggingFace Hub.

    If the model is not already optimized, it is exported to ONNX, optimized,
    and saved to a local directory named after the model. Later calls reuse
    the saved files.

    Args:
        model_name (`str`):
            The name of the model to load.
        optimization_level (`str`):
            The optimization level to use. Should be one of `"O2"`, `"O3"`, or `"O4"`.
        progress (*optional*, defaults to `None`):
            Callable invoked as `progress(fraction, description)` to report
            status; skipped when `None`.

    Returns:
        model (`ORTModelForFeatureExtraction`):
            The optimized model, on the CUDA execution provider.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer.
    """
    optimized_model_name = f"model_optimized_{optimization_level}.onnx"

    model_dir = Path(model_name.replace("/", "_"))

    if not (model_dir / optimized_model_name).exists():
        if progress is not None:
            progress(0.2, "Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.save_pretrained(model_dir)

        if progress is not None:
            progress(0.4, "Downloading model...")

        model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
        model.save_pretrained(model_dir)

        optimizer = ORTOptimizer.from_pretrained(model)
        optimization_config = opt_configs[optimization_level]
        if progress is not None:
            progress(0.6, "Optimizing model...")

        optimizer.optimize(save_dir=model_dir, optimization_config=optimization_config)
        # Rename so that several optimization levels can coexist in one directory.
        (model_dir / "model_optimized.onnx").rename(model_dir / optimized_model_name)

    else:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

    if progress is not None:
        progress(0.8, "Loading optimized model and tokenizer...")

    return (
        ORTModelForFeatureExtraction.from_pretrained(
            model_dir,
            file_name=optimized_model_name,
            provider="CUDAExecutionProvider",
        ),
        tokenizer,
    )


def tokenize(
    examples: Dict[str, List[str]],
    tokenizer: PreTrainedTokenizer,
    column_name: str = "text",
    padding: Union[bool, str] = True,
    max_length: int = 512,
):
    """
    Tokenize the examples using the tokenizer.

    Args:
        examples (`Dict[str, List[str]]`):
            examples to tokenize
        tokenizer (`PreTrainedTokenizer`):
            tokenizer to use
        column_name (`str`, *optional*, defaults to `text`):
            column name to use for tokenization. Defaults to `text`
        padding (`bool`, *optional*, defaults to `True`):
            whether to pad the examples. Defaults to `True`.
            Use `"max_length"` if using the `O4` optimization level.
            If `True`, the batch will be padded to the longest in the batch.
        max_length (`int`, *optional*, Defaults to `512`):
            max length to use for the model. Defaults to `512`.
            Any sequences longer will be truncated.
            If padding is `"max_length"`, padding is added until each
            sequence has length `max_length`.

    Returns:
        `Dict[str, List[List[int]]]`: tokenized examples
    """
    # TODO: add lengths, sort by length, use dynamic padding
    # TODO: option for controlling length for models that can go shorter/longer than 512
    return tokenizer(
        examples[column_name], truncation=True, padding=padding, max_length=max_length
    )


def collate_fn(examples, tokenizer=None, padding=None, device=None):
    """
    Collate a list of tokenized rows into a single batch.

    `input_ids` and `attention_mask` become `torch.long` tensors. Rows are
    right-padded to the longest row in the batch, because rows coming from
    different `map` batches may have been padded to different lengths. All
    other columns (e.g. the raw text) are kept as plain Python lists.

    Args:
        examples (`List[Dict[str, Any]]`):
            Rows yielded by the tokenized dataset.
        tokenizer (`PreTrainedTokenizer`, *optional*):
            If given, its `pad_token_id` is used to pad `input_ids`; otherwise 0.
            Padded positions get attention mask 0, so the pad id does not
            affect mean pooling.
        padding (*optional*):
            Unused; kept for backward compatibility.
        device (`torch.device`, *optional*):
            Device to create the tensors on. `None` uses torch's default device.

    Returns:
        `Dict[str, Union[torch.Tensor, list]]`: the collated batch.

    Raises:
        TypeError: if the rows are not mappings.
    """
    if not examples:
        return {}
    if not isinstance(examples[0], Mapping):
        raise TypeError(
            f"collate_fn expects a list of dict rows, got {type(examples[0]).__name__}"
        )

    batch = {key: [example[key] for example in examples] for key in examples[0]}

    pad_id = 0
    if tokenizer is not None and tokenizer.pad_token_id is not None:
        pad_id = tokenizer.pad_token_id

    for key, pad_value in (("input_ids", pad_id), ("attention_mask", 0)):
        if key not in batch:
            continue
        rows = batch[key]
        width = max(len(row) for row in rows)
        batch[key] = torch.tensor(
            [list(row) + [pad_value] * (width - len(row)) for row in rows],
            dtype=torch.long,
            device=device,
        )

    return batch


@torch.inference_mode()
def batch_embed(
    ds: datasets.IterableDataset,
    model: ORTModelForFeatureExtraction,
    tokenizer: PreTrainedTokenizer,
    model_name: str,
    column_name: str,
    new_dataset_id: str,
    opt_level: str,
    upload_batch_size: int = 10_000,
    map_batch_size: int = 2000,
    num2skip: int = 0,
    num2embed: int = -1,
    progress=None,
):
    """
    Run the model on the dataset and upload the embeddings to the hub.

    Args:
        ds (`datasets.IterableDataset`):
            dataset to embed. From `load_hf_dataset`
        model (`ORTModelForFeatureExtraction`):
            model to use for embedding. From `get_model_and_tokenizer`
        tokenizer (`AutoTokenizer`):
            tokenizer to use for embedding. From `get_model_and_tokenizer`
        model_name (`str`):
            name of the model to use. Used to determine batch size.
        column_name (`str`):
            column name to use for embedding. Default option in gradio app is `text`
        new_dataset_id (`str`):
            id of the new dataset to create. Should include username or organization.
            e.g. nbroad/new-embeddings. If no `/` is present, the current
            user's name is prefixed.
        opt_level (`str`):
            optimization level to use. Should be one of `O2`, `O3`, `O4`.
            See here for more details on optimization levels:
            https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization#optimization-configuration
        upload_batch_size (`int`, *optional*, defaults to `10_000`):
            number of embeddings to upload at once. Defaults to 10,000.
        map_batch_size (`int`, *optional*, defaults to `2000`):
            number of examples to tokenize at once. Defaults to 2000.
        num2skip (`int`, *optional*, defaults to `0`):
            number of examples to skip. Defaults to 0.
        num2embed (`int`, *optional*, defaults to `-1`):
            number of examples to embed. Defaults to -1; any negative
            value means all examples.
        progress (*optional*, defaults to `None`):
            Callable used to report progress; skipped when `None`.

    Returns:
        current_count (`int`):
            number of examples embedded by this call (skipped examples excluded)
        time_taken (`float`):
            time taken to embed the examples in seconds
    """
    api = HfApi(
        token=os.environ["HF_TOKEN"],
    )

    username = api.whoami()["name"]

    if "/" not in new_dataset_id:
        new_dataset_id = username + "/" + new_dataset_id

    # Creates the dataset repo on the Hub and the local clone directory that
    # `push_to_repo` writes parquet files into.
    init_git_repo(new_dataset_id)

    # Skip before tokenizing so skipped rows are not tokenized needlessly.
    if num2skip > 0:
        ds = ds.skip(num2skip)

    ds = ds.map(
        tokenize,
        batched=True,
        batch_size=map_batch_size,
        fn_kwargs={
            "tokenizer": tokenizer,
            "column_name": column_name,
            "padding": "max_length" if opt_level == "O4" else True,
        },
    )

    embeds = []
    texts = []

    # Both counters are absolute positions in the source dataset (they start
    # at `num2skip`), so uploaded file names do not collide when resuming.
    # last_count: position up to which embeddings have been pushed.
    last_count = num2skip
    # current_count: position up to which embeddings have been computed.
    current_count = num2skip

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_name = torch.cuda.get_device_name(0) if device.type == "cuda" else "cpu"
    inference_bs = get_batch_size(device_name, model_name, opt_level)

    start_time = time.time()

    loader = DataLoader(
        ds,
        batch_size=inference_bs,
        shuffle=False,
        num_workers=1,
        # Collate on CPU in the worker process; tensors are moved to `device`
        # in the loop (CUDA tensors can't be created in workers or pinned).
        collate_fn=collate_fn,
        pin_memory=device.type == "cuda",
        drop_last=False,
    )

    for batch in loader:
        ids = batch["input_ids"].to(device, non_blocking=True)
        mask = batch["attention_mask"].to(device, non_blocking=True)
        t_ids = torch.zeros_like(ids)

        outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)

        embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
        texts.extend(batch[column_name])

        current_count += ids.shape[0]

        # Stop once `num2embed` examples are embedded (negative = no limit).
        # Trim the overshoot of the final batch; `keep` is never negative
        # because the previous iteration was still below the limit.
        embedded = current_count - num2skip
        if 0 <= num2embed <= embedded:
            keep = len(embeds) - (embedded - num2embed)
            del embeds[keep:]
            del texts[keep:]
            current_count = num2skip + num2embed
            break

        # Periodically upload to the hub
        if len(embeds) > upload_batch_size:
            push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)
            embeds = []
            texts = []
            last_count = current_count

        # Provide updates
        if progress is not None:
            progress(
                (current_count, None),
                "Embedding docs...",
                total=None,
                unit="Docs Embedded",
            )

    time_taken = time.time() - start_time

    # If there are any remaining embeddings, upload them
    if embeds:
        push_to_repo(new_dataset_id, last_count, current_count, embeds, texts, api)

    return current_count - num2skip, time_taken


def init_git_repo(repo_id: str):
    """
    Create (if needed) a private dataset repo on the Hub and clone it locally.

    ***Removes existing local folder if it cannot be reused as a clone***

    Args:
        repo_id (`str`):
            id of the new dataset to create. Should include username or organization.
            e.g. nbroad/new-embeddings

    Returns:
        `huggingface_hub.Repository`: the local clone, in a folder named
        `repo_id` with `/` replaced by `_`.
    """
    local_dir = repo_id.replace("/", "_")

    create_repo(
        repo_id,
        repo_type="dataset",
        token=os.environ["HF_TOKEN"],
        private=True,
        exist_ok=True,
    )

    clone_kwargs = dict(
        local_dir=local_dir,
        clone_from=repo_id,
        repo_type="dataset",
        token=os.environ["HF_TOKEN"],
        skip_lfs_files=True,
    )
    try:
        repo = Repository(**clone_kwargs)
    except EnvironmentError:
        # Wipe the local folder and retry. ignore_errors so a missing folder
        # does not mask the real error, which the retry will surface.
        shutil.rmtree(local_dir, ignore_errors=True)
        repo = Repository(**clone_kwargs)

    repo.git_pull()

    return repo


def push_to_repo(
    repo_id: str,
    last_count: int,
    current_count: int,
    embeds: List[List[float]],
    texts: List[str],
    api: HfApi,
):
    """
    Write a batch of embeddings to a local parquet file and upload it to the repo.

    Uploads are submitted with `run_as_future=True`, so this function returns
    before the upload has finished.

    Args:
        repo_id (`str`):
            id of the dataset repo. Should include username or organization.
        last_count (`int`):
            position up to which embeddings had already been pushed; start of this batch.
        current_count (`int`):
            position up to which embeddings are pushed after this batch; end of this batch.
        embeds (`List[List[float]]`):
            list of embeddings to push to the repo
        texts (`List[str]`):
            list of texts to push to the repo
        api (`huggingface_hub.HfApi`):
            api to use to push to the repo
    """
    temp_ds = Dataset.from_dict(
        {
            "embedding": embeds,
            "text": texts,
        }
    )

    local_dir = repo_id.replace("/", "_")
    data_dir = Path(local_dir) / "data"
    data_dir.mkdir(exist_ok=True, parents=True)

    # use zfill so sorting puts the files in order
    filename = f"embeddings_{str(last_count).zfill(8)}_{current_count}.parquet"

    filepath = str(data_dir / filename)

    temp_ds.to_parquet(filepath)

    files = sorted(data_dir.glob("*.parquet"))

    if len(files) == 1:
        # First file: upload the whole folder so `data/` is created in the repo.
        api.upload_folder(
            folder_path=str(data_dir),
            repo_id=repo_id,
            repo_type="dataset",
            run_as_future=True,
            token=os.environ["HF_TOKEN"],
            commit_message=f"Embedded examples {last_count} thru {current_count} with folder",
        )
    else:
        api.upload_file(
            path_or_fileobj=filepath,
            path_in_repo=f"data/{filename}",
            repo_id=repo_id,
            repo_type="dataset",
            run_as_future=True,
            token=os.environ["HF_TOKEN"],
            commit_message=f"Embedded examples {last_count} thru {current_count}",
        )

    # Delete old files to bound local disk use.
    # NOTE(review): uploads run asynchronously; if they fall more than a few
    # files behind, an unlinked file may not have been uploaded yet — confirm.
    if len(files) > 4:
        for file in files[:2]:
            file.unlink()