snoop2head committed on
Commit
21d2052
·
1 Parent(s): 9aade45
Files changed (3) hide show
  1. README.md +4 -4
  2. app.py +172 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Kogpt Samhaengshi
3
- emoji: ๐ŸŒ
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: streamlit
7
  app_file: app.py
8
  pinned: false
 
1
  ---
2
+ title: Kogpt Joong 2 Hakgyo
3
+ emoji: ๐Ÿข
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
  sdk: streamlit
7
  app_file: app.py
8
  pinned: false
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import streamlit as st
4
+ from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast
5
+
6
+
7
# Identifier handed to ``from_pretrained`` for both the tokenizer and the model.
model_dir = "snoop2head/kogpt-conditional-2"

# Load the fast tokenizer, spelling out every special token explicitly.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    model_dir,
    **{
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>",
    },
)
16
+
17
+
18
def _model_cache_decorator():
    """Return a Streamlit cache decorator suitable for a large model object.

    Plain ``st.cache`` hashes the cached return value to detect mutation,
    which is expensive for a neural network.  ``st.cache_resource`` (newer
    Streamlit releases) is meant for such shared objects; on releases that
    predate it, fall back to ``st.cache(allow_output_mutation=True)``.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(allow_output_mutation=True)


@_model_cache_decorator()
def load_model(model_name):
    """Load the language model once and reuse it across Streamlit reruns.

    Args:
        model_name: Identifier or path accepted by ``from_pretrained``.

    Returns:
        The model instance produced by ``AutoModelWithLMHead.from_pretrained``.
    """
    return AutoModelWithLMHead.from_pretrained(model_name)


# Module-level model shared by ``infer``.
model = load_model(model_dir)
25
+
26
+
27
def find_nth(haystack, needle, n):
    """Return the index of the n-th non-overlapping occurrence of ``needle``.

    Occurrences are counted from 1 and searched left to right; each search
    resumes right after the end of the previous match.

    Args:
        haystack: String to search in.
        needle: Substring to look for.
        n: 1-based occurrence number.

    Returns:
        The start index of the n-th occurrence, or -1 when there are fewer
        than ``n`` occurrences or when ``n`` is smaller than 1.
    """
    # There is no "zeroth" or negative occurrence; report "not found" instead
    # of silently returning the first match.
    if n < 1:
        return -1

    start = haystack.find(needle)
    for _ in range(n - 1):
        if start < 0:
            break
        start = haystack.find(needle, start + len(needle))
    return start
33
+
34
+
35
def infer(input_ids, max_length, temperature, top_k, top_p):
    """Generate one continuation with the module-level ``model``.

    Args:
        input_ids: Encoded prompt passed to ``model.generate`` (the caller
            passes ``None`` for an empty prompt).
        max_length: Forwarded to ``generate`` as ``max_length``.
        temperature: Sampling temperature.  The sidebar slider allows 0.0,
            which sampling cannot use, so any non-positive value switches to
            greedy decoding instead.
        top_k: Forwarded to ``generate`` as ``top_k`` when sampling.
        top_p: Forwarded to ``generate`` as ``top_p`` when sampling.

    Returns:
        Whatever ``model.generate`` returns for a single requested sequence.
    """
    if temperature <= 0:
        # Sampling needs a strictly positive temperature; greedy decoding is
        # the limiting behaviour as the temperature approaches zero.
        return model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=False,
            num_return_sequences=1,
        )

    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
    )
46
+
47
+
48
# Page header and usage hint.
st.title("์‚ผํ–‰์‹œ์˜ ๋‹ฌ์ธ KoGPT์ž…๋‹ˆ๋‹ค ๐Ÿฆ„")
st.write("ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜๊ณ  CTRL+Enter(CMD+Enter)์„ ๋ˆ„๋ฅด์„ธ์š” ๐Ÿค—")

# Main text box (at most 4 characters) followed by the generation controls
# in the sidebar.
default_value = "๋ฐ•์ˆ˜๋ฏผ"
sent = st.text_area("Text", default_value, max_chars=4, height=275)
max_length = st.sidebar.slider(
    "์ƒ์„ฑ ๋ฌธ์žฅ ๊ธธ์ด๋ฅผ ์„ ํƒํ•ด์ฃผ์„ธ์š”!",
    min_value=42,
    max_value=64,
)
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.0,
    max_value=1.0,
    value=1.0,
    step=0.05,
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider(
    "Top-p",
    min_value=0.0,
    max_value=1.0,
    value=0.9,
    step=0.05,
)

print("slider sidebars rendering completed")

# Emotion labels offered in both sidebar radio groups.
emotion_list = ["ํ–‰๋ณต", "์ค‘๋ฆฝ", "๋ถ„๋…ธ", "ํ˜์˜ค", "๋†€๋žŒ", "์Šฌํ””", "๊ณตํฌ"]
main_emotion = st.sidebar.radio("์ฃผ์š” ๊ฐ์ •์„ ์„ ํƒํ•˜์„ธ์š”", emotion_list)
sub_emotion = st.sidebar.radio("๋‘ ๋ฒˆ์งธ ๊ฐ์ •์„ ์„ ํƒํ•˜์„ธ์š”", emotion_list)

print("radio sidebars rendering completed")

# Build the condition prefix: one normally distributed value (rounded to one
# decimal) per selected emotion is written into the sentence template.
main_draw = np.random.normal(loc=3.368, scale=1.015, size=1)
random_main_logit = main_draw[0].round(1)
sub_draw = np.random.normal(loc=1.333, scale=0.790, size=1)
random_sub_logit = sub_draw[0].round(1)
condition_sentence = f"{random_main_logit}๋งŒํผ {main_emotion}๊ฐ์ •์ธ ๋ฌธ์žฅ์ด๋‹ค. {random_sub_logit}๋งŒํผ {sub_emotion}๊ฐ์ •์ธ ๋ฌธ์žฅ์ด๋‹ค. "
condition_plus_input = f"{condition_sentence}{sent}"
print(condition_plus_input)
77
+
78
+
79
def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=3
):
    """Generate text for a conditioned prompt and return it without the prefix.

    The prompt is encoded, passed to ``infer`` together with the module-level
    ``max_length``, ``temperature`` and ``top_p`` values, decoded, cut at the
    tokenizer's pad token and finally stripped of everything up to and
    including the second condition marker.

    NOTE: ``top_k`` is this function's own parameter (default 3); callers that
    do not pass it do not get the sidebar's Top-k value.

    Args:
        condition_plus_input: Condition sentence followed by the user text.
        tokenizer: Tokenizer used for encoding and decoding.
        top_k: Top-k value forwarded to ``infer``.

    Returns:
        The generated text after the condition prefix, whitespace-stripped.
    """
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    # An empty encoding is replaced by ``None`` before calling ``infer``.
    input_ids = encoded_prompt if encoded_prompt.size()[-1] != 0 else None
    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    print(output_sequences)

    # Work on plain lists of token ids.  Unknown tokens must be detected by
    # id: testing the string "unk" against integer ids can never match (and
    # tensors do not support string membership tests).
    all_sequences = output_sequences.tolist()
    unk_token_id = tokenizer.unk_token_id
    usable_sequences = [
        token_ids
        for token_ids in all_sequences
        if len(token_ids) > 1 and unk_token_id not in token_ids
    ]
    # Fall back to the first raw sequence rather than failing with IndexError
    # when every candidate was filtered out.
    generated_sequence = (usable_sequences or all_sequences)[0]
    print(generated_sequence)

    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    print(text)

    # Remove all text after the stop token -- only when it is actually
    # present; slicing with find()'s -1 would drop the last character.
    stop_token = tokenizer.pad_token
    print(stop_token)
    if stop_token:
        stop_index = text.find(stop_token)
        if stop_index >= 0:
            text = text[:stop_index]
    print(text)

    # Drop the condition prefix: skip past the second marker.  The offset 5 is
    # kept from the original code (presumably marker length plus the trailing
    # period -- confirm against the un-garbled source).  If the marker is
    # missing the text is left untouched instead of losing its first
    # characters.
    condition_index = find_nth(text, "๋ฌธ์žฅ์ด๋‹ค", 2)
    if condition_index >= 0:
        text = text[condition_index + 5 :]
    return text.strip()
122
+
123
+
124
def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
    """Generate one sentence per character of ``input_letter``.

    The first prompt is the condition sentence plus the first character.
    Every later prompt is the condition sentence plus the previously kept
    sentence and the next character; the previous sentence is then removed
    from the new output so each list entry holds only the newly generated
    part.

    Args:
        input_letter: Sequence of characters (e.g. the text typed by the
            user); one sentence is produced per element.
        condition_sentence: Condition prefix placed in front of every prompt.

    Returns:
        A list with one generated string per element of ``input_letter``;
        an empty list when ``input_letter`` is empty.
    """
    list_samhaengshi = []

    for index, letter_item in enumerate(input_letter):
        if index == 0:
            residual_text = letter_item
        else:
            # previous sentence + next letter
            previous_sentence = list_samhaengshi[index - 1]
            residual_text = f"{previous_sentence} {letter_item}"
            print("residual_text", residual_text)

        # infer and add to the output
        conditional_input = f"{condition_sentence} {residual_text}"
        inferred_sentence = infer_sentence(conditional_input, tokenizer)

        if index != 0:
            # remove previous sentence from the output
            print("inferred_sentence:", inferred_sentence)
            inferred_sentence = inferred_sentence.replace(
                previous_sentence, ""
            ).strip()

        list_samhaengshi.append(inferred_sentence)

    # Returning after the loop (not from inside the last iteration) also
    # yields a list for empty input.
    return list_samhaengshi
164
+
165
+
166
# Generate one sentence per typed character, log the result and render it.
return_text = make_residual_conditional_samhaengshi(sent, condition_sentence)

print(return_text)

st.write(return_text)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ streamlit
3
+ torch
4
+ numpy