Commit 21d2052: "add app"
snoop2head committed · 1 parent: 9aade45

Files changed:
- README.md +4 -4
- app.py +172 -0
- requirements.txt +4 -0
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Kogpt
-emoji:
-colorFrom:
-colorTo:
+title: Kogpt Joong 2 Hakgyo
+emoji: 🐢
+colorFrom: yellow
+colorTo: yellow
 sdk: streamlit
 app_file: app.py
 pinned: false
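For context: `title`, `emoji`, `colorFrom`, `colorTo`, `sdk`, `app_file`, and `pinned` are the standard Hugging Face Spaces front-matter keys. `sdk: streamlit` tells the Space to launch `app_file` with Streamlit, while `colorFrom`/`colorTo` only style the gradient on the Space's card thumbnail.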
app.py
ADDED
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import streamlit as st
+from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast
+
+
+model_dir = "snoop2head/kogpt-conditional-2"
+tokenizer = PreTrainedTokenizerFast.from_pretrained(
+    model_dir,
+    bos_token="<s>",
+    eos_token="</s>",
+    unk_token="<unk>",
+    pad_token="<pad>",
+    mask_token="<mask>",
+)
+
+
+@st.cache
+def load_model(model_name):
+    model = AutoModelWithLMHead.from_pretrained(model_name)
+    return model
+
+
+model = load_model(model_dir)
+
+
+def find_nth(haystack, needle, n):
+    start = haystack.find(needle)
+    while start >= 0 and n > 1:
+        start = haystack.find(needle, start + len(needle))
+        n -= 1
+    return start
+
+
+def infer(input_ids, max_length, temperature, top_k, top_p):
+    output_sequences = model.generate(
+        input_ids=input_ids,
+        max_length=max_length,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        do_sample=True,
+        num_return_sequences=1,
+    )
+    return output_sequences
+
+
+# prompts
+st.title("삼행시의 달인 KoGPT입니다 🐦")  # "I'm KoGPT, the master of samhaengshi (three-line acrostics)"
+st.write("텍스트를 입력하고 CTRL+Enter(CMD+Enter)을 누르세요 🤗")  # "Type text and press CTRL+Enter (CMD+Enter)"
+
+# text and sidebars
+default_value = "박상민"  # default acrostic input: a three-syllable Korean name
+sent = st.text_area("Text", default_value, max_chars=4, height=275)
+max_length = st.sidebar.slider("생성 문장 길이를 선택해주세요!", min_value=42, max_value=64)  # "Choose the generated sentence length!"
+temperature = st.sidebar.slider(
+    "Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05
+)
+top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
+top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
+
+print("slider sidebars rendering completed")
+
+# make input sentence
+emotion_list = ["행복", "중립", "분노", "혐오", "놀람", "슬픔", "공포"]  # happiness, neutral, anger, disgust, surprise, sadness, fear
+main_emotion = st.sidebar.radio("주요 감정을 선택하세요", emotion_list)  # "Select the main emotion"
+sub_emotion = st.sidebar.radio("두 번째 감정을 선택하세요", emotion_list)  # "Select the second emotion"
+
+print("radio sidebars rendering completed")
+
+# create condition sentence
+random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1)
+random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1)
+condition_sentence = f"{random_main_logit}만큼 {main_emotion}감정인 문장이다. {random_sub_logit}만큼 {sub_emotion}감정인 문장이다. "  # "a sentence that is {emotion}-emotional to degree {logit}"
+condition_plus_input = condition_sentence + sent
+print(condition_plus_input)
+
+
+def infer_sentence(
+    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=3
+):
+    encoded_prompt = tokenizer.encode(
+        condition_plus_input, add_special_tokens=False, return_tensors="pt"
+    )
+    if encoded_prompt.size()[-1] == 0:
+        input_ids = None
+    else:
+        input_ids = encoded_prompt
+    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
+    print(output_sequences)
+
+    # exclude items that contain "unk"
+    output_sequences = [
+        output_sequence
+        for output_sequence in output_sequences
+        if "unk" not in output_sequence
+    ]
+    # keep items whose length is longer than 1
+    output_sequences = [
+        output_sequence
+        for output_sequence in output_sequences
+        if len(output_sequence) > 1
+    ]
+    generated_sequence = output_sequences[0]
+    print(generated_sequence)
+
+    # print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
+    # generated_sequences = generated_sequence.tolist()
+    # Decode text
+    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+    print(text)
+    # Remove all text after the stop token
+    stop_token = tokenizer.pad_token
+    print(stop_token)
+    text = text[: text.find(stop_token) if stop_token else None]
+    print(text)
+
+    condition_index = find_nth(text, "문장이다", 2)  # locate the 2nd condition sentence
+    text = text[condition_index + 5 :]  # drop everything up to and including "문장이다."
+    text = text.strip()
+    return text
+
+
+def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
+    # collect one generated line per letter of the input string
+    list_samhaengshi = []
+
+    # initializing text and index for iteration purpose
+    index = 0
+
+    # iterating over the input letter string
+    for index, letter_item in enumerate(input_letter):
+        # initializing the input_letter
+        if index == 0:
+            residual_text = letter_item
+            # print('residual_text:', residual_text)
+
+        # infer and add to the output
+        conditional_input = f"{condition_sentence} {residual_text}"
+        inferred_sentence = infer_sentence(conditional_input, tokenizer)
+        if index != 0:
+            # remove previous sentence from the output
+            print("inferred_sentence:", inferred_sentence)
+            inferred_sentence = inferred_sentence.replace(
+                list_samhaengshi[index - 1], ""
+            ).strip()
+        else:
+            pass
+        list_samhaengshi.append(inferred_sentence)
+
+        # until the end of the input_letter, give the previous residual_text to the next iteration
+        if index < len(input_letter) - 1:
+            residual_sentence = list_samhaengshi[index]
+            next_letter = input_letter[index + 1]
+            residual_text = (
+                f"{residual_sentence} {next_letter}"  # previous sentence + next letter
+            )
+            print("residual_text", residual_text)
+
+        elif index == len(input_letter) - 1:  # end of the input_letter
+            # concatenate the strings in the list without overlap
+
+            return list_samhaengshi
+
+
+return_text = make_residual_conditional_samhaengshi(
+    input_letter=sent, condition_sentence=condition_sentence
+)
+
+print(return_text)
+
+st.write(return_text)
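Two notes on the added app.py; the sketches below are editorial illustrations, not part of the commit.

First, the conditioning scheme: the app samples two "emotion logits" from normal distributions (means 3.368 and 1.333, presumably matching the training distribution of snoop2head/kogpt-conditional-2) and prepends two condition sentences of the form "X만큼 Y감정인 문장이다." ("this is a sentence that is Y-emotional to degree X"). The samhaengshi (acrostic) loop then generates one line per input letter, feeding each finished line plus the next letter back in as the prompt so the lines stay coherent. To recover just the generated line, `infer_sentence` finds the second occurrence of "문장이다" and skips five characters (the four-character word plus the period). A minimal, model-free sketch of that trimming step, using an illustrative generated string:

```python
def find_nth(haystack: str, needle: str, n: int) -> int:
    # index of the n-th occurrence of needle, or -1 if there are fewer than n
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start

# illustrative model output: two condition sentences, then the generated line
generated = "3.4만큼 행복감정인 문장이다. 1.3만큼 놀람감정인 문장이다. 하늘을 보면 기분이 좋다"
idx = find_nth(generated, "문장이다", 2)  # start of the second "문장이다"
line = generated[idx + 5 :].strip()       # skip "문장이다." (4 chars + ".")
print(line)                               # -> 하늘을 보면 기분이 좋다
```

Second, API drift: `AutoModelWithLMHead` has been deprecated in transformers since v4 in favor of task-specific classes, and Streamlit has since replaced `@st.cache` with `st.cache_resource` for unserializable objects such as models. A hedged sketch of the same loader on current library versions (these substitutions are assumptions about newer releases, not the commit's code):

```python
import streamlit as st
from transformers import AutoModelForCausalLM  # successor to AutoModelWithLMHead

@st.cache_resource  # successor to @st.cache for models and other resources
def load_model(model_name: str):
    return AutoModelForCausalLM.from_pretrained(model_name)

model = load_model("snoop2head/kogpt-conditional-2")
```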
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+transformers
+streamlit
+torch
+numpy
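All four dependencies are unpinned, so rebuilding the Space later can pull versions with the breaking changes noted above. One option when reproducing this commit is to pin releases from around its era; the pins below are plausible contemporaneous versions, not values from the commit:

```
transformers==4.12.5
streamlit==1.2.0
torch==1.10.0
numpy==1.21.4
```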