demo / app.py
ldhldh's picture
Duplicate from ldhldh/polyglot_ko_1.3B_PEFT_demo
9ca25de
raw
history blame
4.85 kB
from threading import Thread
import torch
import gradio as gr
from transformers import pipeline,AutoTokenizer, AutoModelForCausalLM, BertTokenizer, BertForSequenceClassification, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel, PeftConfig
import re
from kobert_transformers import get_tokenizer
# Module-level setup: report the runtime environment, then load the generative
# model (base causal LM + PEFT adapter) and the two BERT classifiers once at import.
# torch_device is only printed here; nothing in this section moves a model to it.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
# PEFT adapter repo id (a LoRA adapter, going by the repo name).
peft_model_id = "ldhldh/polyglot-ko-1.3b_lora_big_8kstep"
# 18k > has an issue where it also generates the other speaker's lines
# 8k > slightly underwhelming?
# (presumably 18k / 8k refer to adapter training-step checkpoints - TODO confirm)
# The adapter config names the base model; model and tokenizer are both loaded from it.
config = PeftConfig.from_pretrained(peft_model_id)
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
#base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/polyglot-ko-3.8b")
#tokenizer = AutoTokenizer.from_pretrained("EleutherAI/polyglot-ko-3.8b")
base_model.eval()
#base_model.config.use_cache = True
# Wrap the base model with the adapter weights; `model` is what gen() calls.
model = PeftModel.from_pretrained(base_model, peft_model_id, device_map="auto")
model.eval()
#model.config.use_cache = True
# Sequence classifier used by mbti_classify().
mbti_bert_model_name = "Lanvizu/fine-tuned-klue-bert-base_model_11"
mbti_bert_model = BertForSequenceClassification.from_pretrained(mbti_bert_model_name)
mbti_bert_model.eval()
# NOTE(review): the classifier's repo name suggests a klue/bert-base fine-tune, but the
# tokenizer is loaded from "bert-base-uncased" - confirm the vocabularies actually match.
mbti_bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Sequence classifier used by classify(); its tokenizer comes from
# kobert_transformers.get_tokenizer().
bert_model_name = "ldhldh/bert_YN_small"
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name)
bert_model.eval()
bert_tokenizer = get_tokenizer()
# Lazily-built text-classification pipeline, created on first use and then
# reused so each request does not pay for re-wrapping the same model/tokenizer.
_mbti_classifier = None


def mbti_classify(x):
    """Run the MBTI sequence classifier on a single text.

    Args:
        x: The input text to classify.

    Returns:
        The pipeline's result for ``x`` (the first and only element of the
        batch output). The pipeline is created with ``return_all_scores=True``,
        so this is expected to hold one score entry per label.
    """
    global _mbti_classifier
    if _mbti_classifier is None:
        # Same construction arguments as before; only the lifetime changed.
        _mbti_classifier = pipeline(
            "text-classification",
            model=mbti_bert_model,
            tokenizer=mbti_bert_tokenizer,
            return_all_scores=True,
        )
    # The pipeline is called with a one-element batch, so unwrap element 0.
    result = _mbti_classifier([x])
    return result[0]
def classify(x):
    """Classify a single text with ``bert_model`` and return the winning class index.

    Args:
        x: The input text to classify.

    Returns:
        int: Index of the highest logit for ``x``.
    """
    encoded = bert_tokenizer.batch_encode_plus(
        [x], truncation=True, padding=True, return_tensors='pt'
    )
    # Put the inputs on whatever device the classifier lives on.
    input_ids = encoded['input_ids'].to(bert_model.device)
    attention_mask = encoded['attention_mask'].to(bert_model.device)
    # Inference only: skip autograd bookkeeping (no gradient graph is needed
    # to take an argmax over the logits).
    with torch.no_grad():
        outputs = bert_model(input_ids, attention_mask=attention_mask, return_dict=True)
    # Batch of one -> take the first (only) prediction as a plain int.
    return outputs.logits.argmax(dim=1).cpu().tolist()[0]
def gen(x, top_p, top_k, temperature, max_new_tokens, repetition_penalty):
    """Sample a continuation of ``x`` from the PEFT-wrapped causal LM.

    Args:
        x: Prompt text; it is tokenized as-is.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Penalty applied to already-generated tokens.

    Returns:
        str: The decoded first returned sequence (``gened[0]``), decoded with
        the tokenizer's default settings.
    """
    # UI sliders may deliver whole numbers as floats; these two are token counts.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # Tokenize and place the tensors on the model's device so generation also
    # works when the model is not on the CPU.
    inputs = tokenizer(
        f"{x}",
        return_tensors='pt',
        return_token_type_ids=False,
    ).to(model.device)

    gened = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Never ask for more mandatory tokens than the cap allows.
        min_new_tokens=min(5, max_new_tokens),
        # (start_index, decay_factor): start index is documented as an int, so
        # use floor division; the penalty kicks in halfway through the budget.
        exponential_decay_length_penalty=(max_new_tokens // 2, 1.1),
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        early_stopping=True,
        do_sample=True,
        # Token id 2 is used for both end-of-sequence and padding.
        eos_token_id=2,
        pad_token_id=2,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=2,
    )
    model_output = tokenizer.decode(gened[0])
    return model_output
def reset_textbox():
    """Return a Gradio update that sets a component's value to the empty string."""
    cleared = gr.update(value="")
    return cleared
# Build the demo UI: one text input and one shared output box, three action
# buttons (generate / yes-no BERT / MBTI BERT) and the sampling controls.
with gr.Blocks() as demo:
    # Kept from the Space this demo was duplicated from; not referenced below.
    duplicate_link = "https://huggingface.co/spaces/beomi/KoRWKV-1.5B?duplicate=true"
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:EleutherAI/polyglot-ko-1.3b"
    )
    with gr.Row():
        # Left column: prompt, output and the buttons.
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                # Mis-encoded Korean in the example prompt repaired ("여행").
                placeholder='\\nfriend: 우리 여행 갈래? \\nyou:',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
            button_bert = gr.Button(value="bert_Sumit")
            button_mbti_bert = gr.Button(value="mbti_bert_Sumit")
        # Right column: generation hyper-parameters passed to gen().
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=200, value=20, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.8, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            # Top-k is a separate filter from nucleus (top-p) sampling, so it
            # gets its own label.
            top_k = gr.Slider(
                minimum=5, maximum=100, value=30, step=5, interactive=True, label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.5, step=0.1, interactive=True, label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=3.0, value=1.2, step=0.1, interactive=True, label="repetition_penalty",
            )
    # All three handlers write their result into the same output textbox.
    button_submit.click(gen, [user_text, top_p, top_k, temperature, max_new_tokens, repetition_penalty], model_output)
    button_bert.click(classify, [user_text], model_output)
    button_mbti_bert.click(mbti_classify, [user_text], model_output)

# queue() already enables request queuing (at most 32 waiting requests); the
# redundant `enable_queue` launch argument is deprecated/removed in newer
# Gradio releases, so it is no longer passed.
demo.queue(max_size=32).launch()