---
license: mit
language:
- ru
- en
tags:
- transformers
- sentence-transformers
---
|
|
|
# Model Card for ru-en-RoSBERTa |
|
|
|
ru-en-RoSBERTa is a general-purpose text embedding model for Russian. The model is based on [ruRoBERTa](https://huggingface.co/ai-forever/ruRoberta-large) and fine-tuned on ~4M pairs of supervised, synthetic, and unsupervised data in Russian and English. The tokenizer also supports some English tokens from the [RoBERTa](https://huggingface.co/FacebookAI/roberta-large) tokenizer.
|
|
|
For more details about the model, please refer to our [article](arxiv).
|
|
|
## Usage |
|
|
|
The model can be used as-is with task prefixes. CLS pooling is recommended. The choice of prefix and pooling depends on the task.
|
|
|
We use the following basic rules to choose a prefix (a small helper sketch follows the list):
|
- `"search_query: "` and `"search_document: "` prefixes are for answer or relevant paragraph retrieval
- `"classification: "` prefix is for symmetric paraphrasing-related tasks (STS, NLI, bitext mining)
- `"clustering: "` prefix is for any task that relies on thematic features (topic classification, title-body retrieval)
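
For illustration, here is a tiny hypothetical helper (`add_prefix` is our name, not part of the model's API) that prepends one of these prefixes to a batch of texts:

```python
PREFIXES = ("search_query: ", "search_document: ", "classification: ", "clustering: ")


def add_prefix(texts, prefix):
    # Prepend a supported task prefix to every input text.
    assert prefix in PREFIXES
    return [prefix + text for text in texts]


queries = add_prefix(["Сколько программистов нужно, чтобы вкрутить лампочку?"], "search_query: ")
```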
|
|
|
To better tailor the model to your needs, you can fine-tune it on relevant high-quality Russian and English datasets, as in the sketch below.
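
A minimal fine-tuning sketch with the sentence-transformers trainer; the training pairs here are toy stand-ins, and the loss, batch size, and schedule are illustrative choices rather than the authors' recipe:

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("ai-forever/ru-en-RoSBERTa")

# Toy positive pairs; replace with a relevant high-quality dataset.
train_examples = [
    InputExample(texts=[
        "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
        "search_document: Чтобы вкрутить лампочку, требуется три программиста.",
    ]),
    InputExample(texts=[
        "search_query: How many programmers does it take to change a light bulb?",
        "search_document: It takes three programmers to change a light bulb.",
    ]),
]

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
# In-batch negatives loss, commonly used for retrieval-style fine-tuning.
train_loss = losses.MultipleNegativesRankingLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
```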
|
|
|
Below are examples of encoding texts with the Transformers and SentenceTransformers libraries.
|
|
|
### Transformers |
|
|
|
```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def pool(hidden_state, mask, pooling_method="cls"):
    # Mean pooling averages token embeddings, ignoring padded positions;
    # CLS pooling takes the embedding of the first ([CLS]) token.
    if pooling_method == "mean":
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(dim=1, keepdim=True).float()
        return s / d
    elif pooling_method == "cls":
        return hidden_state[:, 0]


inputs = [
    # first text of each pair
    "classification: Он нам и <unk> не нужон ваш Интернет!",
    "clustering: В Ярославской области разрешили работу бань, но без посетителей",
    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",

    # second text of each pair
    "classification: What a time to be alive!",
    "clustering: Ярославским баням разрешили работать без посетителей",
    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
]

tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")

tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**tokenized_inputs)

embeddings = pool(
    outputs.last_hidden_state,
    tokenized_inputs["attention_mask"],
    pooling_method="cls",  # or try "mean"
)

embeddings = F.normalize(embeddings, p=2, dim=1)

# Cosine similarity between the i-th text and the (i+3)-th text.
sim_scores = embeddings[:3] @ embeddings[3:].T
print(sim_scores.diag().tolist())
# [0.4796873927116394, 0.9409002065658569, 0.7761015892028809]
```
|
|
|
### SentenceTransformers |
|
|
|
```python
from sentence_transformers import SentenceTransformer

inputs = [
    # first text of each pair
    "classification: Он нам и <unk> не нужон ваш Интернет!",
    "clustering: В Ярославской области разрешили работу бань, но без посетителей",
    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",

    # second text of each pair
    "classification: What a time to be alive!",
    "clustering: Ярославским баням разрешили работать без посетителей",
    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
]

# Loads the model with CLS pooling.
model = SentenceTransformer("ai-forever/ru-en-RoSBERTa")

# Embeddings are normalized by default.
embeddings = model.encode(inputs, convert_to_tensor=True)

# Cosine similarity between the i-th text and the (i+3)-th text.
sim_scores = embeddings[:3] @ embeddings[3:].T
print(sim_scores.diag().tolist())
# [0.47968706488609314, 0.940900444984436, 0.7761018872261047]
```
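
For retrieval-style usage, here is a minimal sketch that reuses the texts above; `util.cos_sim` is the cosine-similarity helper from sentence-transformers:

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("ai-forever/ru-en-RoSBERTa")

query = "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?"
documents = [
    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
    "search_document: В Ярославской области разрешили работу бань, но без посетителей",
]

query_embedding = model.encode(query, convert_to_tensor=True)
document_embeddings = model.encode(documents, convert_to_tensor=True)

# Cosine similarity between the query and every document, shape (1, len(documents)).
scores = util.cos_sim(query_embedding, document_embeddings)
best_idx = scores.argmax().item()
print(documents[best_idx])
```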
|
|
|
## Citation |
|
|
|
TODO |
|
|
|
## Limitations |
|
|
|
The model is designed to process texts in Russian; quality on English texts is unknown. The maximum input length is limited to 512 tokens.
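
To check whether a text fits within the limit before encoding, one can count tokens with the model's tokenizer (a minimal sketch; the count includes special tokens):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")

text = "clustering: В Ярославской области разрешили работу бань, но без посетителей"
num_tokens = len(tokenizer(text)["input_ids"])  # includes special tokens
if num_tokens > 512:
    print(f"Text is {num_tokens} tokens long and will be truncated to 512.")
```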