This is the converted model from Unbabel/wmt22-cometkiwi-da

2) Renamed the keys to match the original Facebook/XLM-roberta-large
3) Kept the layer_wise_attention / estimator layers

Because of a hack in HF's loading code (state dict keys containing "gamma" get rewritten to "weight" when a checkpoint is loaded), I had to rename the "layerwise_attention.gamma" key to "layerwise_attention.gam".
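
For illustration, here is a minimal sketch of the kind of key renaming this conversion involves; the checkpoint path and the prefix mapping below are assumptions, not the exact script used for this repo.

```
import torch

# Illustrative sketch only: load the original COMET checkpoint (a PyTorch
# Lightning .ckpt with a "state_dict" entry) and rename keys to
# XLM-RoBERTa-style names. The prefix mapping here is an assumption.
ckpt = torch.load("model.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

renamed = {}
for key, value in state_dict.items():
    new_key = key.replace("encoder.model.", "")  # hypothetical prefix mapping
    # avoid transformers' "gamma" -> "weight" rewrite at load time
    new_key = new_key.replace("layerwise_attention.gamma", "layerwise_attention.gam")
    renamed[new_key] = value

torch.save(renamed, "pytorch_model.bin")
```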

I changed the config.json key "layer_transformation" from sparsemax to softmax because of a bug in COMET: the flag is not passed down, so the function actually used is the default, which is softmax.
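
As a rough sketch of what this flag controls (not the exact COMET code; layer count and tensor sizes are made up), the layer-wise attention normalizes one scalar weight per encoder layer with the chosen transformation and mixes the per-layer hidden states:

```
import torch

# Illustrative only: softmax-normalized scalar weights, one per encoder layer,
# used as mixing coefficients over the per-layer hidden states.
num_layers, batch, seq_len, hidden = 25, 1, 10, 1024
hidden_states = torch.randn(num_layers, batch, seq_len, hidden)
layer_weights = torch.randn(num_layers)

norm_weights = torch.softmax(layer_weights, dim=0)   # "layer_transformation": "softmax"
mixed = (norm_weights.view(-1, 1, 1, 1) * hidden_states).sum(dim=0)
print(mixed.shape)  # torch.Size([1, 10, 1024])
```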
Usage:

```
from transformers import XLMRobertaTokenizerFast, AutoModel

tokenizer = XLMRobertaTokenizerFast.from_pretrained("vince62s/wmt22-cometkiwi-da-roberta-large", trust_remote_code=True)
model = AutoModel.from_pretrained("vince62s/wmt22-cometkiwi-da-roberta-large", trust_remote_code=True)

# quality-estimation input: the MT hypothesis, then the source, joined by "</s> </s>"
text = "Hello world! </s> </s> Bonjour le monde"
encoded_text = tokenizer(text, return_tensors='pt')
print(encoded_text)
output = model(**encoded_text)
print(output[0])

# expected output:
# {'input_ids': tensor([[ 0, 35378, 8999, 38, 2, 2, 84602, 95, 11146, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
# tensor([[0.8640]], grad_fn=<AddmmBackward0>)
```
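
A hedged sketch for scoring several segments with the converted model, reusing the same "mt </s> </s> src" input format pair by pair; the second sentence pair is only illustrative, and `tokenizer` / `model` are the objects created above.

```
# Assumes `tokenizer` and `model` from the snippet above are already loaded.
pairs = [
    ("Bonjour le monde", "Hello world!"),       # (src, mt) from the example above
    ("Le chat dort.", "The cat is sleeping."),  # illustrative extra pair
]
for src, mt in pairs:
    enc = tokenizer(f"{mt} </s> </s> {src}", return_tensors="pt")
    score = model(**enc)[0].item()              # output is a 1x1 tensor holding the QE score
    print(f"{score:.4f}  {mt!r} <- {src!r}")
```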

Let's double-check with the original code from Unbabel COMET:

```
from comet import download_model, load_from_checkpoint

model = load_from_checkpoint("/home/vincent/Downloads/cometkiwi22/checkpoints/model.ckpt")  # this is the Unbabel checkpoint
data = [{"mt": "Hello world!", "src": "Bonjour le monde"}]
output = model.predict(data, gpus=0)
print(output)

# expected output:
# Prediction([('scores', [0.863973081111908]),
#             ('system_score', 0.863973081111908)])
```
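
If the Unbabel checkpoint is not already on disk, it can also be fetched through COMET's own downloader instead of a hard-coded local path (the model is gated, so accepting the license and a Hugging Face login may be required):

```
from comet import download_model, load_from_checkpoint

# Download the gated Unbabel checkpoint from the Hub and load it.
model_path = download_model("Unbabel/wmt22-cometkiwi-da")
model = load_from_checkpoint(model_path)
```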