---
language: ko
tags:
- gpt2
- conversational
license: cc-by-nc-sa-4.0
---

## Ko-DialoGPT

### How to use
```python
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = PreTrainedTokenizerFast.from_pretrained('byeongal/Ko-DialoGPT')
model = GPT2LMHeadModel.from_pretrained('byeongal/Ko-DialoGPT').to(device)

past_user_inputs = []
generated_responses = []

while True:
    user_input = input(">> User: ")
    if user_input == 'bye':
        break
    # Encode the current turn; every turn ends with the EOS token as a separator.
    text_idx = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
    # Prepend at most the last two (user, bot) exchanges, newest first,
    # stopping as soon as the prompt would reach 1,000 tokens.
    for i in range(len(generated_responses) - 1, len(generated_responses) - 3, -1):
        if i < 0:
            break
        encoded_vector = tokenizer.encode(generated_responses[i] + tokenizer.eos_token, return_tensors='pt')
        if text_idx.shape[-1] + encoded_vector.shape[-1] < 1000:
            text_idx = torch.cat([encoded_vector, text_idx], dim=-1)
        else:
            break
        encoded_vector = tokenizer.encode(past_user_inputs[i] + tokenizer.eos_token, return_tensors='pt')
        if text_idx.shape[-1] + encoded_vector.shape[-1] < 1000:
            text_idx = torch.cat([encoded_vector, text_idx], dim=-1)
        else:
            break
    text_idx = text_idx.to(device)
    # Beam search; max_length counts prompt tokens plus generated tokens.
    # top_k is only used in sampling modes, so it has no effect with plain beam search.
    inference_output = model.generate(
        text_idx,
        max_length=1000,
        num_beams=5,
        top_k=20,
        no_repeat_ngram_size=4,
        length_penalty=0.65,
        repetition_penalty=2.0,
    )
    inference_output = inference_output.tolist()
    # Decode only the newly generated tokens (everything after the prompt).
    bot_response = tokenizer.decode(inference_output[0][text_idx.shape[-1]:], skip_special_tokens=True)
    print(f"Bot: {bot_response}")
    past_user_inputs.append(user_input)
    generated_responses.append(bot_response)
```
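The history handling in the loop above can be isolated as a pure function, which makes it testable without downloading the model. This is a hypothetical refactoring, not part of the original card: `build_context`, `max_turns`, and `max_tokens` are illustrative names, and the function works on plain lists of token ids rather than tensors. It follows the same rules as the loop: at most the last two exchanges, newest first, stopping at the first turn that would bring the prompt to 1,000 tokens or more.

```python
from typing import List, Sequence


# Hypothetical helper (not part of the original model card).
def build_context(
    user_ids: List[int],
    past_user_ids: Sequence[List[int]],
    past_bot_ids: Sequence[List[int]],
    max_turns: int = 2,
    max_tokens: int = 1000,
) -> List[int]:
    """Prepend up to `max_turns` previous (user, bot) exchanges to the current
    input, newest first, stopping as soon as the token budget would be reached.

    Every element is assumed to be a list of token ids that already ends with
    the EOS id, like `tokenizer.encode(text + tokenizer.eos_token)`.
    """
    context = list(user_ids)
    n = len(past_bot_ids)
    # Walk backwards over the most recent exchanges (at most `max_turns`).
    for i in range(n - 1, max(n - 1 - max_turns, -1), -1):
        # Bot reply first, then the user turn that prompted it, so the final
        # order is user_i, bot_i, ..., current user input.
        for segment in (past_bot_ids[i], past_user_ids[i]):
            if len(context) + len(segment) >= max_tokens:
                return context
            context = list(segment) + context
    return context
```

To plug it into the loop, each turn would be encoded with `tokenizer.encode(text + tokenizer.eos_token)` (which returns a plain list when `return_tensors` is omitted), and the result wrapped as `torch.tensor([context], device=device)` before calling `model.generate`.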

### Reference

* [SKT-KoGPT2](https://huggingface.co/skt/kogpt2-base-v2)
* [KETI R&D Data](https://aihub.or.kr/opendata/keti-data/recognition-laguage/KETI-02-008) (KETI R&D 데이터)
* [Korean Dialogue Summarization](https://aihub.or.kr/aidata/30714) (한국어 대화 요약)