byeongal commited on
Commit
0502168
·
1 Parent(s): 228f755

update model and README.md

Browse files
Files changed (2) hide show
  1. README.md +43 -13
  2. pytorch_model.bin +1 -1
README.md CHANGED
@@ -10,21 +10,51 @@ license: cc-by-nc-sa 4.0
10
 
11
  ### How to use
12
  ```python
13
- from transformers import AutoModelForCausalLM, AutoTokenizer
14
  import torch
15
 
16
- tokenizer = AutoTokenizer.from_pretrained("byeongal/Ko-DialoGPT")
17
- model = AutoModelForCausalLM.from_pretrained("byeongal/Ko-DialoGPT")
18
-
19
- for step in range(2):
20
- # encode the new user input, add the eos_token and return a tensor in Pytorch
21
- new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
22
- # append the new user input tokens to the chat history
23
- bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
24
- # generated a response while limiting the total chat history to 1000 tokens,
25
- chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
26
- # pretty print last ouput tokens from bot
27
- print("BOT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ```
29
 
30
  ### Reference
 
10
 
11
  ### How to use
12
  ```python
13
+ from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel
14
  import torch
15
 
16
+
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+
19
+ tokenizer = PreTrainedTokenizerFast.from_pretrained('byeongal/Ko-DialoGPT')
20
+ model = GPT2LMHeadModel.from_pretrained('byeongal/Ko-DialoGPT').to(device)
21
+
22
+ past_user_inputs = []
23
+ generated_responses = []
24
+
25
+ while True:
26
+ user_input = input(">> User:")
27
+ if user_input == 'bye':
28
+ break
29
+ text_idx = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
30
+ for i in range(len(generated_responses)-1, len(generated_responses)-3, -1):
31
+ if i < 0:
32
+ break
33
+ encoded_vector = tokenizer.encode(generated_responses[i] + tokenizer.eos_token, return_tensors='pt')
34
+ if text_idx.shape[-1] + encoded_vector.shape[-1] < 1000:
35
+ text_idx = torch.cat([encoded_vector, text_idx], dim=-1)
36
+ else:
37
+ break
38
+ encoded_vector = tokenizer.encode(past_user_inputs[i] + tokenizer.eos_token, return_tensors='pt')
39
+ if text_idx.shape[-1] + encoded_vector.shape[-1] < 1000:
40
+ text_idx = torch.cat([encoded_vector, text_idx], dim=-1)
41
+ else:
42
+ break
43
+ text_idx = text_idx.to(device)
44
+ inference_output = model.generate(
45
+ text_idx,
46
+ max_length=1000,
47
+ num_beams=5,
48
+ top_k=20,
49
+ no_repeat_ngram_size=4,
50
+ length_penalty=0.65,
51
+ repetition_penalty=2.0,
52
+ )
53
+ inference_output = inference_output.tolist()
54
+ bot_response = tokenizer.decode(inference_output[0][text_idx.shape[-1]:], skip_special_tokens=True)
55
+ print(f"Bot: {bot_response}")
56
+ past_user_inputs.append(user_input)
57
+ generated_responses.append(bot_response)
58
  ```
59
 
60
  ### Reference
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6dec0cdd37c1d0af26f80bacf0c692e2fdd8045be4617042920d2ffa6104b4df
3
  size 513305211
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9603cdeb38c018ae9d3d0f9af83f12952a590a35cd87654d0a505d6cf43991fb
3
  size 513305211