|
from threading import Thread |
|
|
|
import torch |
|
import gradio as gr |
|
from transformers import pipeline,AutoTokenizer, AutoModelForCausalLM, BertTokenizer, BertForSequenceClassification, StoppingCriteria, StoppingCriteriaList |
|
from peft import PeftModel, PeftConfig |
|
import re |
|
from kobert_transformers import get_tokenizer |
|
|
|
# ---------------------------------------------------------------------------
# Runtime diagnostics
# ---------------------------------------------------------------------------
# torch_device is only reported here; nothing in this section moves a model
# onto it.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# ---------------------------------------------------------------------------
# Causal language model: base checkpoint wrapped with a PEFT adapter
# ---------------------------------------------------------------------------
peft_model_id = "ldhldh/1.3_40kstep"
config = PeftConfig.from_pretrained(peft_model_id)

# The adapter config names the base checkpoint; both the base model and its
# tokenizer are loaded from that same identifier.
_base_checkpoint = config.base_model_name_or_path
base_model = AutoModelForCausalLM.from_pretrained(_base_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(_base_checkpoint)

base_model.eval()
base_model.config.use_cache = True

# Attach the adapter weights on top of the already-loaded base model.
model = PeftModel.from_pretrained(base_model, peft_model_id, device_map="auto")
model.eval()
model.config.use_cache = True

# ---------------------------------------------------------------------------
# MBTI sequence classifier
# ---------------------------------------------------------------------------
mbti_bert_model_name = "Lanvizu/fine-tuned-klue-bert-base_model_11"
mbti_bert_model = BertForSequenceClassification.from_pretrained(
    mbti_bert_model_name
).eval()
# NOTE(review): the checkpoint name suggests a KLUE-BERT fine-tune, yet the
# tokenizer comes from "bert-base-uncased" -- confirm this is the tokenizer the
# classifier was actually trained with.
mbti_bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# ---------------------------------------------------------------------------
# Second sequence classifier (bert_YN_small), tokenized with KoBERT's tokenizer
# ---------------------------------------------------------------------------
bert_model_name = "ldhldh/bert_YN_small"
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name).eval()
bert_tokenizer = get_tokenizer()
|
|
|
|
|
# Lazily-built text-classification pipeline shared by all mbti_classify calls.
_mbti_classifier = None


def _get_mbti_classifier():
    """Return the MBTI classification pipeline, constructing it on first use."""
    global _mbti_classifier
    if _mbti_classifier is None:
        # Building a pipeline is comparatively expensive, so do it once rather
        # than on every request. return_all_scores=True is kept so that the
        # output format stays exactly what callers already receive.
        _mbti_classifier = pipeline(
            "text-classification",
            model=mbti_bert_model,
            tokenizer=mbti_bert_tokenizer,
            return_all_scores=True,
        )
    return _mbti_classifier


def mbti_classify(x):
    """Run the MBTI classifier on a single text.

    Args:
        x: The input text to classify.

    Returns:
        The pipeline result for ``x`` (the first and only entry of the
        single-item batch). With ``return_all_scores=True`` this is expected
        to hold one label/score entry per class.
    """
    result = _get_mbti_classifier()([x])
    return result[0]
|
|
|
|
|
def classify(x):
    """Classify a single text with ``bert_model`` and return the winning class index.

    Args:
        x: The input text to classify.

    Returns:
        int: Index of the largest logit produced for ``x``.
    """
    encoded = bert_tokenizer.batch_encode_plus(
        [x], truncation=True, padding=True, return_tensors='pt'
    )
    # Place the inputs on whatever device the classifier currently lives on.
    input_ids = encoded['input_ids'].to(bert_model.device)
    attention_masks = encoded['attention_mask'].to(bert_model.device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = bert_model(input_ids, attention_mask=attention_masks, return_dict=True)

    # Batch size is one, so take the first (only) prediction.
    return outputs.logits.argmax(dim=1).cpu().tolist()[0]
|
|
|
def gen(x, top_p, top_k, temperature, max_new_tokens, repetition_penalty):
    """Sample a continuation of ``x`` from the causal language model.

    Args:
        x: Prompt text.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Penalty applied to already-generated tokens.

    Returns:
        str: The decoded text of the first sequence returned by
        ``model.generate``.
    """
    # UI widgets may hand these over as floats; generate() expects whole counts.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # Never ask for more mandatory tokens than the caller's overall budget
    # (the UI allows max_new_tokens as low as 1).
    min_new_tokens = min(5, max_new_tokens)

    inputs = tokenizer(
        f"{x}",
        return_tensors='pt',
        return_token_type_ids=False
    )
    # Keep the input tensors on the same device as the model.
    inputs = inputs.to(model.device)

    gened = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        # (start index, decay factor): the start index is documented as an
        # int, so use integer division for the halfway point.
        exponential_decay_length_penalty=(max_new_tokens // 2, 1.1),
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        early_stopping=True,
        do_sample=True,
        # NOTE(review): token id 2 is hard-coded for both EOS and padding --
        # confirm it matches the tokenizer's EOS id.
        eos_token_id=2,
        pad_token_id=2,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=2
    )

    model_output = tokenizer.decode(gened[0])
    return model_output
|
|
|
def reset_textbox():
    """Build a Gradio update that sets a component's value to the empty string.

    Returns:
        The object produced by ``gr.update(value='')``.
    """
    cleared = gr.update(value="")
    return cleared
|
|
|
|
|
# Gradio UI: one text input, one shared output box, three action buttons and
# the sampling controls that are forwarded to gen().
with gr.Blocks() as demo:
    # NOTE(review): duplicate_link is assigned but not referenced in the
    # visible code.
    duplicate_link = "https://huggingface.co/spaces/beomi/KoRWKV-1.5B?duplicate=true"
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:EleutherAI/polyglot-ko-1.3b"
    )

    with gr.Row():
        # Wide column: prompt, output and the buttons that trigger each model.
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                # The Korean sample text was stored mis-encoded (mojibake);
                # restored here to readable Korean ("Shall we go on a trip?").
                placeholder='\\nfriend: 우리 여행 갈래? \\nyou:',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
            button_bert = gr.Button(value="bert_Sumit")
            button_mbti_bert = gr.Button(value="mbti_bert_Sumit")

        # Narrow column: generation hyper-parameters.
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=200, value=20, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.8, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            # NOTE(review): the label mentions nucleus sampling, but this
            # slider is passed to gen() as top_k.
            top_k = gr.Slider(
                minimum=5, maximum=100, value=30, step=5, interactive=True, label="Top-k (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.5, step=0.1, interactive=True, label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=3.0, value=1.2, step=0.1, interactive=True, label="repetition_penalty",
            )

    # All three buttons write their result into the same output textbox.
    button_submit.click(
        gen,
        [user_text, top_p, top_k, temperature, max_new_tokens, repetition_penalty],
        model_output,
    )
    button_bert.click(classify, [user_text], model_output)
    button_mbti_bert.click(mbti_classify, [user_text], model_output)

# Queuing is already enabled by .queue(); the deprecated enable_queue launch
# argument was redundant and is not accepted by newer Gradio releases.
demo.queue(max_size=32).launch()