import torch
from nltk import sent_tokenize
import nltk
from tqdm import tqdm
import gradio as gr
from transformers import T5ForConditionalGeneration, T5Tokenizer
nltk.download("punkt")
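# note: newer NLTK releases may also require the "punkt_tab" resource for
# sent_tokenize; if so, add nltk.download("punkt_tab") here as well.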
# autodetect the available device
GPU_IDX = 1 # which GPU to use
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    assert GPU_IDX < num_gpus, f"GPU index {GPU_IDX} not available."
    device = torch.device(f"cuda:{GPU_IDX}")
    print(f"Using GPU: {GPU_IDX}")
else:
    print("CUDA is not available. Using CPU instead.")
    device = torch.device("cpu")
# Configuration for models and their adapters
model_config = {
    "Base Model": "polygraf-ai/poly-humanizer-base",
    "Large Model": "polygraf-ai/poly-humanizer-large",
    "XL Model": {
        "path": "google/flan-t5-xl",
        "adapters": {
            "XL Model Adapter": "polygraf-ai/poly-humanizer-XL-adapter",
            # "XL Law Model Adapter": "polygraf-ai/poly-humanizer-XL-law-adapter",
            # "XL Marketing Model Adapter": "polygraf-ai/marketing-cleaned-13K-grad-acum-4-full",
            # "XL Child Style Model Adapter": "polygraf-ai/poly-humanizer-XL-children-adapter-checkpoint-4000",
        },
    },
}
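# The commented-out entries above are optional domain-specific adapters
# (law, marketing, children's style) that can be re-enabled by uncommenting them.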
# cache the base models, tokenizers, and adapters
models, tokenizers = {}, {}
for name, config in model_config.items():
    path = config if isinstance(config, str) else config["path"]
    # initialize the model and tokenizer
    # note: bfloat16 weights assume hardware with bf16 support (e.g., Ampere-class GPUs)
    model = T5ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
    models[name] = model
    tokenizers[name] = T5Tokenizer.from_pretrained(path)
    # load all available adapters, each adding roughly 150M parameters
    if isinstance(config, dict) and "adapters" in config:
        for adapter_name, adapter_path in config["adapters"].items():
            model.load_adapter(adapter_path, adapter_name=adapter_name)
            print(f"Loaded adapter: {adapter_name}, Num. params: {model.num_parameters()}")
def paraphrase_text(
    text,
    progress=gr.Progress(),
    model_name="Base Model",
    temperature=1.2,
    repetition_penalty=1.0,
    top_k=50,
    length_penalty=1.0,
):
    progress(0, desc="Starting to Humanize")
    progress(0.05)
    # select the model, tokenizer, and adapter
    if "XL" in model_name:  # dynamic adapter load/unload for XL models
        # all adapter models use the XL model as the base
        tokenizer, model = tokenizers["XL Model"], models["XL Model"]
        # set the adapter if it's not already set
        if model.active_adapters() != [f"{model_name} Adapter"]:
            model.set_adapter(f"{model_name} Adapter")
            print(f"Using adapter: {model_name} Adapter")
    else:
        tokenizer = tokenizers[model_name]
        model = models[model_name]
    # split the text into paragraphs
    paragraphs = text.split("\n")
    humanized_paragraphs = []
    for paragraph in progress.tqdm(paragraphs, desc="Humanizing"):
        # paraphrase each sentence of the paragraph individually
        sentences = sent_tokenize(paragraph)
        paraphrases = []
        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) == 0:
                continue
            inputs = tokenizer(
                "Please paraphrase this sentence: " + sentence,
                return_tensors="pt",
            ).to(device)
            outputs = model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                max_length=128,
                top_k=top_k,
                length_penalty=length_penalty,
            )
            paraphrased_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
            paraphrases.append(paraphrased_sentence)
            print(f"\nOriginal: {sentence}")
            print(f"Paraphrased: {paraphrased_sentence}")
        combined_paraphrase = " ".join(paraphrases)
        humanized_paragraphs.append(combined_paraphrase)
    humanized_text = "\n\n".join(humanized_paragraphs)
    return humanized_text
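# --- Hedged usage sketch (not shown in the file as rendered above) ---
# A minimal Gradio wiring for paraphrase_text, assuming a simple Blocks UI;
# the Space's actual interface may differ. Component names and slider ranges
# below are illustrative assumptions. Gradio injects its tracked progress
# object into the `progress=gr.Progress()` parameter automatically, so the
# UI inputs map onto the remaining parameters in order.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        input_text = gr.Textbox(label="Input text", lines=8)
        model_choice = gr.Dropdown(
            ["Base Model", "Large Model", "XL Model"],
            value="Base Model",
            label="Model",
        )
        temperature = gr.Slider(0.5, 2.0, value=1.2, label="Temperature")
        repetition_penalty = gr.Slider(1.0, 2.0, value=1.0, label="Repetition penalty")
        top_k = gr.Slider(10, 100, value=50, step=1, label="Top-k")
        length_penalty = gr.Slider(0.5, 2.0, value=1.0, label="Length penalty")
        output_text = gr.Textbox(label="Humanized text", lines=8)
        gr.Button("Humanize").click(
            fn=paraphrase_text,
            inputs=[input_text, model_choice, temperature, repetition_penalty, top_k, length_penalty],
            outputs=output_text,
        )
    demo.launch()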