# -*- coding: utf-8 -*-
"""Streamlit demo that writes a samhaengshi (one line per input letter).

Every line is generated by the ``snoop2head/kogpt-conditional-2`` causal LM,
prompted with an emotion "condition sentence" followed by the previously
generated line and the next letter of the user's input.
"""
import numpy as np
import streamlit as st
from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast

model_dir = "snoop2head/kogpt-conditional-2"

# NOTE(review): the special tokens below are empty strings in this copy of the
# file; they look like `<unk>`/`<pad>`-style placeholders that may have been
# stripped at some point.  Confirm the intended values for this model.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    model_dir,
    bos_token="",
    eos_token="",
    unk_token="",
    pad_token="",
    mask_token="",
)


@st.cache
def load_model(model_name):
    """Load the language model, cached by Streamlit across reruns."""
    model = AutoModelWithLMHead.from_pretrained(model_name)
    return model


model = load_model(model_dir)
print("loaded model completed")


def find_nth(haystack, needle, n):
    """Return the index of the n-th (1-based) occurrence of ``needle``.

    Returns -1 when there are fewer than ``n`` occurrences.
    """
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start


def infer(input_ids, max_length, temperature, top_k, top_p):
    """Sample a single continuation of ``input_ids`` with the global model."""
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
    )
    return output_sequences


# prompts
st.title("삼행시의 달인 KoGPT입니다 🦄")
st.write("텍스트를 입력하고 CTRL+Enter(CMD+Enter)을 누르세요 🤗")

# text and sidebars
default_value = "박수민"
sent = st.text_area("Text", default_value, max_chars=4, height=275)
max_length = st.sidebar.slider("생성 문장 길이를 선택해주세요!", min_value=42, max_value=64)
# Sampling (do_sample=True) requires a strictly positive temperature, so the
# slider starts at one step above zero instead of 0.0.
temperature = st.sidebar.slider(
    "Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
print("slider sidebars rendering completed")

# make input sentence
emotion_list = ["행복", "중립", "분노", "혐오", "놀람", "슬픔", "공포"]
main_emotion = st.sidebar.radio("주요 감정을 선택하세요", emotion_list)
sub_emotion = st.sidebar.radio("두 번째 감정을 선택하세요", emotion_list)
print("radio sidebars rendering completed")

# create condition sentence
# NOTE(review): the loc/scale constants presumably mirror the distribution of
# emotion scores used when the conditional model was trained -- TODO confirm.
random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1)
random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1)
condition_sentence = (
    f"{random_main_logit}만큼 {main_emotion}감정인 문장이다. "
    f"{random_sub_logit}만큼 {sub_emotion}감정인 문장이다. "
)
condition_plus_input = condition_sentence + sent
print(condition_plus_input)


def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2
):
    """Generate one continuation for a prompt and strip the condition prefix.

    Args:
        condition_plus_input: prompt text (condition sentence + seed text).
        tokenizer: tokenizer used to encode the prompt and decode the output.
        top_k: top-k cutoff forwarded to ``infer``.
            NOTE(review): this default (2) shadows the sidebar "Top-k" slider,
            so the slider has no effect for callers using the default.

    Returns:
        The decoded text that follows the second "문장이다." marker (the end of
        the condition sentence), stripped of surrounding whitespace.
    """
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    # An empty prompt is passed to generate() as None, not a zero-length tensor.
    input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt

    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    print(output_sequences)

    unk_token = tokenizer.unk_token

    def _is_usable(sequence):
        # Keep sequences longer than one token whose decoded text does not
        # contain the unknown token.  (`"unk" in tensor` is not a valid test:
        # Tensor.__contains__ only accepts tensors/scalars.)
        if len(sequence) <= 1:
            return False
        if not unk_token:
            return True
        return unk_token not in tokenizer.decode(sequence)

    candidates = [seq for seq in output_sequences if _is_usable(seq)]
    # Fall back to the raw output instead of raising IndexError when every
    # sequence was filtered out.
    generated_sequence = candidates[0] if candidates else output_sequences[0]
    print(generated_sequence)

    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    print(text)

    # Remove all text after the stop token -- only when it actually occurs;
    # slicing at find() == -1 would silently drop the last character.
    stop_token = tokenizer.pad_token
    print(stop_token)
    if stop_token:
        stop_index = text.find(stop_token)
        if stop_index != -1:
            text = text[:stop_index]
    print(text)

    # Skip past the second "문장이다" plus its trailing period.  If the marker
    # is missing, leave the text as-is rather than slicing at a bogus offset.
    marker = "문장이다"
    condition_index = find_nth(text, marker, 2)
    if condition_index != -1:
        text = text[condition_index + len(marker) + 1 :]
    return text.strip()


def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
    """Generate one line per letter of ``input_letter``.

    The first prompt is the condition sentence plus the first letter.  Each
    later prompt is the condition sentence plus the previous line and the next
    letter; the previous line is then removed from the model output so that
    only the new line is kept.

    Args:
        input_letter: string whose letters start each line.
        condition_sentence: emotion condition prefix used for every prompt.

    Returns:
        A list with one generated line per letter (empty for empty input).
    """
    list_samhaengshi = []
    residual_text = ""

    for index, letter_item in enumerate(input_letter):
        if index == 0:
            residual_text = letter_item

        conditional_input = f"{condition_sentence} {residual_text}"
        inferred_sentence = infer_sentence(conditional_input, tokenizer)

        if index != 0:
            # The prompt contained the previous line; drop it from the output.
            print("inferred_sentence:", inferred_sentence)
            inferred_sentence = inferred_sentence.replace(
                list_samhaengshi[index - 1], ""
            ).strip()
        list_samhaengshi.append(inferred_sentence)

        # Seed the next iteration with this line followed by the next letter.
        if index < len(input_letter) - 1:
            residual_sentence = list_samhaengshi[index]
            next_letter = input_letter[index + 1]
            residual_text = f"{residual_sentence} {next_letter}"
            print("residual_text", residual_text)

    return list_samhaengshi


return_text = make_residual_conditional_samhaengshi(
    input_letter=sent, condition_sentence=condition_sentence
)
print(return_text)
st.write(return_text)