Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,135 @@
|
|
1 |
---
|
2 |
license: mit
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: mit
|
3 |
+
language:
|
4 |
+
- it
|
5 |
---
|
6 |
+
--------------------------------------------------------------------------------------------------
|
7 |
+
|
8 |
+
<body>
|
9 |
+
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
|
10 |
+
<br>
|
11 |
+
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> </span>
|
12 |
+
<br>
|
13 |
+
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;"> Model: DIABLO 🔥</span>
|
14 |
+
<br>
|
15 |
+
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;"> Lang: IT</span>
|
16 |
+
<br>
|
17 |
+
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> </span>
|
18 |
+
<br>
|
19 |
+
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
|
20 |
+
</body>
|
21 |
+
|
22 |
+
--------------------------------------------------------------------------------------------------
|
23 |
+
|
24 |
+
<h3>Model description</h3>
|
25 |
+
|
26 |
+
This model is a <b>conversational</b> language model for the <b>Italian</b> language, based on a GPT-like <b>[1]</b> architecture (more specifically, the model has been obtained by modifying Meta's XGLM architecture <b>[2]</b> and exploiting its 1.7B checkpoint).
|
27 |
+
|
28 |
+
The model has been trained on a corpus of \~50K Italian conversational exchanges for \~3 epochs (\~15K steps with a batch size of 10), using 3 different learning rates (1e-5, 2e-6, 1e-6) and exploiting FP16 quantization to manage the considerable size of the model.
|
29 |
+
The training corpus has been built by using Meta's Blenderbot <b>[3]</b> to generate 50K conversational exchanges in English, and then translating them to the Italian language using a machine traslation model.
|
30 |
+
|
31 |
+
The current release is designed for brief and informal conversations (small talk) covering light topics (mainly food, entertainment and holidays), but several generalizations and improvements will be introduced in future releases.
|
32 |
+
|
33 |
+
|
34 |
+
<h3>Quick usage</h3>
|
35 |
+
|
36 |
+
In order to use the model for inference, the following pipeline is needed:
|
37 |
+
|
38 |
+
```python
|
39 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
40 |
+
import torch
|
41 |
+
import re
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained("osiria/diablo-italian-chatbot-1.3b")
|
44 |
+
model = AutoModelForCausalLM.from_pretrained("osiria/diablo-italian-chatbot-1.3b")
|
45 |
+
device = torch.device("cpu")
|
46 |
+
model = model.to(device)
|
47 |
+
model.eval()
|
48 |
+
|
49 |
+
class Diablo:
|
50 |
+
|
51 |
+
def __init__(self, tokenizer, model):
|
52 |
+
self.tokenizer = tokenizer
|
53 |
+
self.model = model
|
54 |
+
|
55 |
+
def _check_sublist(self, lst, sub_lst, sep = " "):
|
56 |
+
|
57 |
+
l_type = type(lst[0])
|
58 |
+
lst = sep.join(list(map(str, lst)))
|
59 |
+
sub_lst = sep.join(list(map(str, sub_lst)))
|
60 |
+
|
61 |
+
return sub_lst in lst
|
62 |
+
|
63 |
+
def _exclude_sublist(self, lst, sub_lst, sep = " "):
|
64 |
+
|
65 |
+
l_type = type(lst[0])
|
66 |
+
lst = sep.join(list(map(str, lst)))
|
67 |
+
sub_lst = sep.join(list(map(str, sub_lst)))
|
68 |
+
lst = re.sub("\s+", " ", lst.replace(sub_lst, "")).strip().split(sep)
|
69 |
+
lst = list(map(l_type, lst))
|
70 |
+
|
71 |
+
return lst
|
72 |
+
|
73 |
+
def generate(self, prompt, sep = "|", max_tokens = 100, excluded = [[40, 19]],
|
74 |
+
lookback = 1, stop_tokens = [5, 27, 33], sample = False, top_k = 3):
|
75 |
+
|
76 |
+
tokens = tokenizer.encode(prompt + sep)
|
77 |
+
tokens_generated = []
|
78 |
+
while tokens[-1] not in stop_tokens and len(tokens) < max_tokens:
|
79 |
+
output = model.forward(input_ids=torch.tensor([tokens]).to(device)).logits[0,-1]
|
80 |
+
output = torch.softmax(output, dim = 0)
|
81 |
+
candidates = torch.topk(output, k = top_k)
|
82 |
+
if sample:
|
83 |
+
indices = candidates.indices
|
84 |
+
scores = candidates.values
|
85 |
+
next_token = indices[torch.multinomial(scores, 1)[0].item()]
|
86 |
+
else:
|
87 |
+
next_token = candidates.indices[0]
|
88 |
+
next_token = next_token.item()
|
89 |
+
sub_tokens = tokens_generated[-lookback:] + [next_token]
|
90 |
+
if len(tokens_generated) >= (lookback + 1) and next_token in tokens_generated[-(lookback + 1):]:
|
91 |
+
next_token = candidates.indices[1]
|
92 |
+
next_token = next_token.item()
|
93 |
+
elif len(tokens_generated) >= lookback and self._check_sublist(tokens_generated, sub_tokens):
|
94 |
+
next_token = candidates.indices[1]
|
95 |
+
next_token = next_token.item()
|
96 |
+
tokens = tokens + [next_token]
|
97 |
+
tokens_generated = tokens_generated + [next_token]
|
98 |
+
for ex_lst in excluded:
|
99 |
+
tokens = self._exclude_sublist(tokens, ex_lst)
|
100 |
+
output = tokenizer.decode(tokens, skip_special_tokens=True)
|
101 |
+
output = output.split(sep)[-1].strip()
|
102 |
+
output = output[0].upper() + output[1:]
|
103 |
+
if output[-1] == tokenizer.decode(stop_tokens[0]):
|
104 |
+
output = output[:-1]
|
105 |
+
|
106 |
+
return output
|
107 |
+
|
108 |
+
diablo = Diablo(tokenizer = tokenizer, model = model)
|
109 |
+
|
110 |
+
prompt = "Ciao, come stai?"
|
111 |
+
|
112 |
+
# setting "sample = True" the model will be more creative but occasionally less accurate
|
113 |
+
print("OUTPUT:", diablo.generate(prompt, sample = False))
|
114 |
+
|
115 |
+
# OUTPUT: Sto bene, grazie
|
116 |
+
```
|
117 |
+
|
118 |
+
|
119 |
+
<h3>Limitations</h3>
|
120 |
+
|
121 |
+
This model has been mainly trained on machine-translated (and synthetic) conversational data, so it might behave erratically when presented with prompts which are too far away from its training set.
|
122 |
+
Moreover, the heterogeneous nature of the pretraining dataset, together with the limits of the conversational data, might lead the model to produce biased or offensive content with respect to gender, race, ideologies, and political or religious beliefs.
|
123 |
+
These limitations imply that the model and its outputs should be used with caution, and should not be involved in situations that require the generated text to be fair or true.
|
124 |
+
|
125 |
+
<h3>References</h3>
|
126 |
+
|
127 |
+
[1] https://arxiv.org/abs/2005.14165
|
128 |
+
|
129 |
+
[2] https://arxiv.org/abs/2112.10668
|
130 |
+
|
131 |
+
[3] https://arxiv.org/pdf/2004.13637.pdf
|
132 |
+
|
133 |
+
<h3>License</h3>
|
134 |
+
|
135 |
+
The model is released under <b>MIT</b> license
|