tolga-ozturk committed · Commit b12d7ce · Parent: 7b1d190
Update Readme

README.md CHANGED
@@ -7,9 +7,9 @@ tags:
 - next-sentence-prediction
 - gpt
 datasets:
-- wikipedia
+- wikipedia
 metrics:
-- accuracy
+- accuracy
 ---
 
 # mGPT-nsp
@@ -18,12 +18,12 @@ mGPT-nsp is fine-tuned for Next Sentence Prediction task on the [wikipedia datas
 
 ## Model description
 
-mGPT-nsp is a Transformer-based model which fine-tuned for Next Sentence Prediction task on 11000 English and 11000 German Wikipedia articles.
+mGPT-nsp is a Transformer-based model which was fine-tuned for the Next Sentence Prediction task on 11000 English and 11000 German Wikipedia articles. We use the same tokenization and vocabulary as the [mT5 model](https://huggingface.co/google/mt5-base).
 
 ## Intended uses
 
 - Apply Next Sentence Prediction tasks (compare the results with BERT models, since BERT natively supports this task)
-- See how to fine-tune
+- See how to fine-tune an mGPT2 model using our [code](https://github.com/slds-lmu/stereotypes-multi/tree/main)
 - Check our [paper](https://arxiv.org/abs/2307.07331) to see its results
 
 ## How to use
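The first bullet under "Intended uses" above suggests comparing the results with BERT, which has a native NSP head. A minimal comparison sketch is below; the `bert-base-multilingual-cased` checkpoint and the example sentence pair are illustrative choices, not prescribed by this card.

```python
# Rough comparison sketch (not from the card): BERT ships with a native NSP head,
# so a sentence pair can be scored directly with BertForNextSentencePrediction.
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
bert_model = BertForNextSentencePrediction.from_pretrained("bert-base-multilingual-cased").eval()

encoding = bert_tokenizer("In Italy, pizza is presented unsliced.",
                          "However, it is served sliced in Turkey.",
                          return_tensors="pt")
with torch.no_grad():
    logits = bert_model(**encoding).logits  # shape (1, 2)

# For BERT's NSP head, index 0 scores "sentence B follows sentence A"
# and index 1 scores "sentence B is a random sentence".
print(torch.softmax(logits, dim=-1))
```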
@@ -40,26 +40,23 @@ class ModelNSP(torch.nn.Module):
     def __init__(self, pretrained_model="THUMT/mGPT"):
         super(ModelNSP, self).__init__()
         self.core_model = GPT2Model.from_pretrained(pretrained_model)
+        self.nsp_head = torch.nn.Sequential(torch.nn.Linear(self.core_model.config.hidden_size, 300),
+                                            torch.nn.Linear(300, 300), torch.nn.Linear(300, 2))
 
     def forward(self, input_ids, attention_mask=None):
-        return self.nsp_head(core_model_outputs).softmax(dim=-1)
+        return self.nsp_head(self.core_model(input_ids, attention_mask=attention_mask)[0].mean(dim=1)).softmax(dim=-1)
 
-weights = torch.load(hf_hub_download(repo_id="tolga-ozturk/mGPT-nsp", filename="model_weights.bin"))
 model = torch.nn.DataParallel(ModelNSP().eval())
-model.load_state_dict(weights)
+model.load_state_dict(torch.load(hf_hub_download(repo_id="tolga-ozturk/mGPT-nsp", filename="model_weights.bin")))
 tokenizer = MT5Tokenizer.from_pretrained("tolga-ozturk/mGPT-nsp")
 ```
 
 ### Inference
 ```python
-print(torch.argmax(outputs, dim=-1))
+batch_texts = [("In Italy, pizza is presented unsliced.", "The sky is blue."),
+               ("In Italy, pizza is presented unsliced.", "However, it is served sliced in Turkey.")]
+encoded_dict = tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_texts, truncation="longest_first", padding=True, return_tensors="pt", return_attention_mask=True, max_length=256)
+print(torch.argmax(model(encoded_dict.input_ids, attention_mask=encoded_dict.attention_mask), dim=-1))
 ```
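The inference snippet above prints only the argmax per pair, and the card does not state which of the two output indices means "is the next sentence". A small illustrative follow-up (reusing `model`, `batch_texts`, and `encoded_dict` from the snippet) prints the full probabilities so the index-to-label mapping can be checked against the paper:

```python
# Illustrative follow-up, not from the card: show the full class probabilities
# for each candidate pair instead of only the argmax. The model's forward pass
# already applies softmax, so each row sums to 1.
import torch

with torch.no_grad():
    probs = model(encoded_dict.input_ids, attention_mask=encoded_dict.attention_mask)

for (first, second), row in zip(batch_texts, probs.tolist()):
    print(f"{first!r} + {second!r} -> {row}")
```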
 
 ## BibTeX entry and citation info
@@ -74,4 +71,4 @@ print(torch.argmax(outputs, dim=-1))
 }
 ```
 
-The work is done with Ludwig-Maximilians-Universität Statistics group, don't forget to check out [their huggingface page](https://huggingface.co/misoda) for other interesting works!
+This work was done with the Ludwig-Maximilians-Universität Statistics group; don't forget to check out [their Hugging Face page](https://huggingface.co/misoda) for other interesting work!