""" summarize - a module for summarizing text using a model from the Hugging Face model hub """ import logging import pprint as pp logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s") import torch from tqdm.auto import tqdm from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from utils import validate_pytorch2 def load_model_and_tokenizer(model_name: str) -> tuple: """ load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub :param str model_name: the model name/ID on the hub :return tuple: a tuple containing the model and tokenizer """ MODEL_OPTIONS = { "Text Summarizer": "pszemraj/long-t5-tglobal-base-16384-book-summary", "News Artical Summarizer Alpha": "pszemraj/long-t5-tglobal-base-sci-simplify", "News Artical Summarizer Beta": "pszemraj/long-t5-tglobal-base-sci-simplify-elife", "Scientific Document Summarizer Alpha": "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1", "Scientific Document Summarizer Beta": "pszemraj/pegasus-x-large-book-summary", } selected_model_identifier = MODEL_OPTIONS.get(model_name) device = "cuda" if torch.cuda.is_available() else "cpu" model = AutoModelForSeq2SeqLM.from_pretrained( selected_model_identifier, ).to(device) model = model.eval() tokenizer = AutoTokenizer.from_pretrained(selected_model_identifier) logging.info(f"Loaded model {selected_model_identifier} to {device}") if validate_pytorch2(): try: logging.info("Compiling model with Torch 2.0") model = torch.compile(model) except Exception as e: logging.warning(f"Could not compile model with Torch 2.0: {e}") else: logging.info("Torch 2.0 not detected, skipping compilation") return model, tokenizer def summarize_and_score( ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs ) -> tuple: """ summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary Args: ids (): the batch of ids mask (): the attention mask for the batch model (): the model to use for summarization tokenizer (): the tokenizer to use for summarization is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True. **kwargs: any additional arguments to pass to the model Returns: tuple (str, float): the summary, the score for the summary """ ids = ids[None, :] mask = mask[None, :] input_ids = ids.to("cuda") if torch.cuda.is_available() else ids attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask global_attention_mask = torch.zeros_like(attention_mask) # put global attention on token global_attention_mask[:, 0] = 1 if is_general_attention_model: summary_pred_ids = model.generate( input_ids, attention_mask=attention_mask, output_scores=True, return_dict_in_generate=True, **kwargs, ) else: summary_pred_ids = model.generate( input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, output_scores=True, return_dict_in_generate=True, **kwargs, ) summary = tokenizer.batch_decode( summary_pred_ids.sequences, skip_special_tokens=True, remove_invalid_values=True, ) score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4) return summary, score def summarize_via_tokenbatches( input_text: str, model, tokenizer, batch_length=2048, batch_stride=16, min_batch_length=512, **kwargs, ) -> list: """ summarize_via_tokenbatches - summarize a long string via batches of tokens Args: input_text (str): the text to summarize model (): the model to use for summarization tokenizer (): the tokenizer to use for summarization batch_length (int, optional): the length of each batch. Defaults to 2048. batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches. min_batch_length (int, optional): the minimum length of each batch. Defaults to 512. **kwargs: any additional arguments to pass to the model for inference Returns: list: a list of dictionaries containing the input tokens, the summary, and the summary score """ logger = logging.getLogger(__name__) # log all input parameters if batch_length < min_batch_length: logger.warning( f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}" ) batch_length = min_batch_length logger.info(f"input parameters:\n{pp.pformat(kwargs)}") logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}") encoded_input = tokenizer( input_text, padding="max_length", truncation=True, max_length=batch_length, stride=batch_stride, return_overflowing_tokens=True, add_special_tokens=False, return_tensors="pt", ) in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask gen_summaries = [] pbar = tqdm(total=len(in_id_arr)) for _id, _mask in zip(in_id_arr, att_arr): result, score = summarize_and_score( ids=_id, mask=_mask, model=model, tokenizer=tokenizer, **kwargs, ) score = round(float(score), 4) _sum = { "input_tokens": _id, "summary": result, "summary_score": score, } gen_summaries.append(_sum) logger.debug(f"Score for batch: {score}. num chars: {len(repr(result))}") logger.debug(f"Summary:\n\t{result}") pbar.update() pbar.close() return gen_summaries