|
|
|
import numpy as np |
|
import streamlit as st |
|
from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast |
|
|
|
model_dir = "snoop2head/kogpt-conditional-2"

# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer is loaded through a resource cache instead of being re-fetched on
# every rerun. `st.cache_resource` is used where available; older Streamlit
# releases only provide `st.cache`, where `allow_output_mutation=True` skips
# hashing the returned object.
_cache_tokenizer = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_tokenizer
def _load_tokenizer(name):
    """Load the fast tokenizer for ``name``, registering its special-token strings.

    Args:
        name: Hugging Face hub id or local directory of the tokenizer.

    Returns:
        A ``PreTrainedTokenizerFast`` instance.
    """
    return PreTrainedTokenizerFast.from_pretrained(
        name,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
    )


tokenizer = _load_tokenizer(model_dir)
|
|
|
|
|
# Prefer `st.cache_resource`, which stores the model without hashing it. The
# legacy `st.cache` fallback needs `allow_output_mutation=True`; otherwise it
# hashes the entire returned model after every call to detect mutation.
_cache_model = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_model
def load_model(model_name):
    """Load the causal language model ``model_name``, cached across script reruns.

    Args:
        model_name: Hugging Face hub id or local directory of the model.

    Returns:
        The loaded ``transformers`` model.
    """
    # Imported locally so the file's top-level import block stays unchanged;
    # `AutoModelForCausalLM` is the replacement for the deprecated
    # `AutoModelWithLMHead` when loading GPT-style (causal) checkpoints.
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(model_name)


model = load_model(model_dir)
print("loaded model completed")
|
|
|
|
|
def find_nth(haystack, needle, n):
    """Return the index of the ``n``-th non-overlapping occurrence of ``needle``.

    Each subsequent search resumes just past the end of the previous match,
    so matches never overlap. Returns -1 when fewer than ``n`` occurrences
    exist; an ``n`` of 1 or less yields the first occurrence.
    """
    step = len(needle)
    remaining = n
    index = haystack.find(needle)
    while index != -1 and remaining > 1:
        remaining -= 1
        index = haystack.find(needle, index + step)
    return index
|
|
|
|
|
def infer(input_ids, max_length, temperature, top_k, top_p):
    """Generate one continuation of ``input_ids`` with the module-level ``model``.

    Args:
        input_ids: Encoded prompt tensor, or ``None`` to generate without a prompt.
        max_length: ``max_length`` forwarded to ``model.generate``.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding instead of sampling.
        top_k: Top-k filtering value, used only when sampling.
        top_p: Nucleus (top-p) filtering value, used only when sampling.

    Returns:
        The output-sequence tensor returned by ``model.generate``
        (``num_return_sequences=1``).
    """
    generation_kwargs = {}
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    else:
        # transformers rejects temperature == 0 when sampling (the temperature
        # warper requires a strictly positive float), yet the temperature
        # slider goes down to 0.0; treat that as "no randomness", i.e. greedy.
        generation_kwargs["do_sample"] = False

    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        num_return_sequences=1,
        **generation_kwargs,
    )
|
|
|
|
|
|
|
# Page title and an introductory line of text rendered at the top of the app.
# NOTE(review): these string literals look mis-encoded (Thai-script characters
# where Korean UI text is presumably intended), and the title literal is broken
# across two lines; restore the original text and confirm the file is UTF-8.
st.title("์ฃผ์ด์ง ๊ฐ์ ์ ๋ง๊ฒ ๋ฌธ์ฅ์ ๋ง๋๋ KoGPT์
๋๋ค ๐ฆ") |
|
st.write("์ข์ธก์ ๊ฐ์ ์ํ์ ๋ณํ๋ฅผ ์ฃผ๊ณ , CTRL+Enter(CMD+Enter)๋ฅผ ๋๋ฅด์ธ์ ๐ค") |
|
|
|
|
|
# Prompt text that the generated sentence continues from (at most 30 chars).
default_value = "์์ํ ๋ฐค๋ค์ด ๊ณ์๋๋ ๋ ์ธ์ ๊ฐ๋ถํฐ ๋๋"
sent = st.text_area("Text", value=default_value, height=50, max_chars=30)

# Generation controls rendered in the sidebar, in display order.
max_length = st.sidebar.slider(
    "์์ฑ ๋ฌธ์ฅ ๊ธธ์ด๋ฅผ ์ ํํด์ฃผ์ธ์!", min_value=42, max_value=64
)
temperature = st.sidebar.slider(
    "Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider(
    "Top-p", min_value=0.0, max_value=1.0, value=1.0, step=0.05
)

print("slider sidebars rendering completed")
|
|
|
|
|
# Emotion choices for the two sidebar radio buttons.
# NOTE(review): these literals look mis-encoded (Thai-script characters where
# Korean is presumably intended) and one list item is split across two lines;
# restore the original text.
emotion_list = ["ํ๋ณต", "๋๋", "๋ถ๋
ธ", "ํ์ค", "์ฌํ", "๊ณตํฌ", "์ค๋ฆฝ"] |
|
main_emotion = st.sidebar.radio("์ฃผ์ ๊ฐ์ ์ ์ ํํ์ธ์", emotion_list) |
|
# The second radio lists the same emotions in reverse order; emotion_list is
# reversed in place between the two calls.
# NOTE(review): if st.radio keeps a reference to its options list, this
# in-place reverse() also affects the first radio; passing emotion_list[::-1]
# to the second radio would avoid mutating the shared list.
emotion_list.reverse() |
|
sub_emotion = st.sidebar.radio("๋ ๋ฒ์งธ ๊ฐ์ ์ ์ ํํ์ธ์", emotion_list) |
|
|
|
print("radio sidebars rendering completed") |
|
|
|
|
|
# Draw a random intensity (rounded to one decimal place) for the main and the
# secondary emotion; the prompt states both before the user's text.
random_main_logit = np.round(np.random.normal(loc=3.368, scale=1.015, size=1)[0], 1)
random_sub_logit = np.round(np.random.normal(loc=1.333, scale=0.790, size=1)[0], 1)
condition_sentence = f"{random_main_logit}๋งํผ {main_emotion}๊ฐ์ ์ธ ๋ฌธ์ฅ์ด๋ค. {random_sub_logit}๋งํผ {sub_emotion}๊ฐ์ ์ธ ๋ฌธ์ฅ์ด๋ค. "
condition_plus_input = f"{condition_sentence}{sent}"
print(condition_plus_input)
|
|
|
|
|
def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2
):
    """Generate text for the conditioned prompt and strip the condition prefix.

    Reads the module-level ``max_length``, ``temperature`` and ``top_p``
    values and calls ``infer`` to run generation.

    Args:
        condition_plus_input: Condition sentences followed by the user's text.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        top_k: Top-k value forwarded to ``infer``.

    Returns:
        The decoded text after the second condition marker (or the whole
        decoded text if the marker is not found twice), truncated at the first
        pad token and stripped of surrounding whitespace.
    """
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    # An empty prompt is passed to generation as None.
    input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt

    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    print(output_sequences)

    generated_sequence = output_sequences[0]
    print(generated_sequence)

    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    print(text)

    # Cut the text at the pad token. str.find returns -1 when the token is
    # absent, and slicing with -1 would silently drop the last character, so
    # only truncate when the token is actually present.
    stop_token = tokenizer.pad_token
    print(stop_token)
    if stop_token:
        stop_index = text.find(stop_token)
        if stop_index != -1:
            text = text[:stop_index]
    print(text)

    # The prompt holds two condition sentences, each ending in the marker
    # followed by "."; the user's text and its continuation follow the second
    # one. Skip past marker + "." when found; otherwise keep the whole text
    # instead of slicing from a bogus offset.
    marker = "๋ฌธ์ฅ์ด๋ค"
    condition_index = find_nth(text, marker, 2)
    if condition_index != -1:
        text = text[condition_index + len(marker) + 1 :]

    return text.strip()
|
|
|
|
|
# Generate for the current prompt. The sidebar Top-k value is passed
# explicitly; otherwise infer_sentence falls back to its own default
# (top_k=2) and the Top-k slider has no effect.
return_text = infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=top_k
)

print(return_text)

st.write(return_text)
|
|