update model and README.md
Browse files- README.md +43 -13
- pytorch_model.bin +1 -1
README.md
CHANGED
@@ -10,21 +10,51 @@ license: cc-by-nc-sa-4.0
|
|
10 |
|
11 |
### How to use
|
12 |
```python
|
13 |
-
from transformers import
|
14 |
import torch
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
```
|
29 |
|
30 |
### Reference
|
|
|
10 |
|
11 |
### How to use
|
12 |
```python
|
13 |
+
"""Interactive chat demo for the byeongal/Ko-DialoGPT model.

Type a message at the ">> User:" prompt; type 'bye' to exit. Each turn, up
to the two most recent user/bot exchanges are prepended as context, capped
at a ~1000-token budget, and the model's continuation is printed as the
bot's reply.
"""
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = PreTrainedTokenizerFast.from_pretrained('byeongal/Ko-DialoGPT')
model = GPT2LMHeadModel.from_pretrained('byeongal/Ko-DialoGPT').to(device)

past_user_inputs = []      # user turns, oldest first
generated_responses = []   # bot turns, parallel to past_user_inputs

while True:
    user_input = input(">> User:")
    if user_input == 'bye':
        break
    text_idx = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
    # Walk the last two exchanges newest-first, prepending bot reply then the
    # matching user turn, stopping as soon as the 1000-token budget is hit.
    for i in range(len(generated_responses) - 1, len(generated_responses) - 3, -1):
        if i < 0:
            break
        encoded_vector = tokenizer.encode(generated_responses[i] + tokenizer.eos_token, return_tensors='pt')
        if text_idx.shape[-1] + encoded_vector.shape[-1] < 1000:
            text_idx = torch.cat([encoded_vector, text_idx], dim=-1)
        else:
            break
        encoded_vector = tokenizer.encode(past_user_inputs[i] + tokenizer.eos_token, return_tensors='pt')
        if text_idx.shape[-1] + encoded_vector.shape[-1] < 1000:
            text_idx = torch.cat([encoded_vector, text_idx], dim=-1)
        else:
            break
    text_idx = text_idx.to(device)
    # Inference only: disable autograd tracking to cut memory/overhead.
    with torch.no_grad():
        inference_output = model.generate(
            text_idx,
            max_length=1000,
            num_beams=5,
            top_k=20,
            no_repeat_ngram_size=4,
            length_penalty=0.65,
            repetition_penalty=2.0,
            # GPT-2 has no pad token; be explicit to avoid the transformers
            # "Setting pad_token_id to eos_token_id" warning at every turn.
            pad_token_id=tokenizer.eos_token_id,
        )
    inference_output = inference_output.tolist()
    # The reply is everything generated past the prompt tokens.
    bot_response = tokenizer.decode(inference_output[0][text_idx.shape[-1]:], skip_special_tokens=True)
    print(f"Bot: {bot_response}")
    past_user_inputs.append(user_input)
    generated_responses.append(bot_response)
|
58 |
```
|
59 |
|
60 |
### Reference
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 513305211
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9603cdeb38c018ae9d3d0f9af83f12952a590a35cd87654d0a505d6cf43991fb
|
3 |
size 513305211
|