---
language:
- ru
tags:
- causal-lm
- text-generation
license:
- apache-2.0
inference: false
widget:
- text: "Как обрести просветление?<s>"
  example_title: "Википедия"
---

# RuGPT3Medium-tathagata

## Model description

This is a Russian text-generation model based on [rugpt3medium_based_on_gpt2](https://huggingface.co/sberbank-ai/rugpt3medium_based_on_gpt2).

## Intended uses & limitations

#### How to use

This model was trained and used to generate text on an RTX 3080 GPU.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Use the GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The tokenizer comes from the base model.
model_name_or_path = "sberbank-ai/rugpt3medium_based_on_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# "model" is assumed to be a local directory containing the fine-tuned checkpoint.
model = GPT2LMHeadModel.from_pretrained("model").to(DEVICE)
model.eval()

text = "В чем смысл жизни?\n"  # "What is the meaning of life?"
input_ids = tokenizer.encode(text, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    out = model.generate(
        input_ids,
        do_sample=True,
        num_beams=4,
        temperature=1.1,
        top_p=0.9,
        top_k=50,
        max_length=250,
        min_length=50,
        early_stopping=True,
        num_return_sequences=3,
        no_repeat_ngram_size=3,
    )

# generate() returns three sequences here; decode and print the first one.
generated_text = tokenizer.decode(out[0])
print()
print(generated_text)
```
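
The snippet above prints only the first of the three sequences requested by `num_return_sequences=3`. A minimal sketch of decoding all of them follows. The names `decode_candidates`, `fake_decode`, and the toy `vocab` are hypothetical, introduced only for illustration, and are not part of the model or of `transformers`. A stand-in decoder is used so that the sketch runs without downloading any weights.

```python
from typing import Callable, List, Sequence


def decode_candidates(
    sequences: Sequence[Sequence[int]],
    decode: Callable[[Sequence[int]], str],
    prompt: str = "",
) -> List[str]:
    """Decode every generated sequence and drop the echoed prompt, if present."""
    texts = []
    for token_ids in sequences:
        text = decode(token_ids)
        # Causal LMs return the prompt followed by the continuation.
        if prompt and text.startswith(prompt):
            text = text[len(prompt):]
        texts.append(text.strip())
    return texts


# Hypothetical stand-in for tokenizer.decode, backed by a toy vocabulary.
vocab = {0: "Q?", 1: " first", 2: " second", 3: " answer"}


def fake_decode(token_ids: Sequence[int]) -> str:
    return "".join(vocab[i] for i in token_ids)


candidates = decode_candidates([[0, 1, 3], [0, 2, 3]], fake_decode, prompt="Q?")
for index, candidate in enumerate(candidates, start=1):
    print(f"[{index}] {candidate}")
```

With the real objects, a call such as `decode_candidates(out, tokenizer.decode, prompt=text)` should behave the same way, assuming the decoded text starts with the prompt exactly as it was typed.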

## Dataset

- Dataset: [tathagata](https://huggingface.co/datasets/radm/tathagata)
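
The dataset can probably be fetched with the `datasets` library. The sketch below assumes that the `radm/tathagata` repository loads with the library's default builder; the helper name `load_tathagata` is hypothetical, and the splits and column names are not documented here, so they should be inspected after loading.

```python
DATASET_ID = "radm/tathagata"  # repository id taken from the dataset link


def load_tathagata():
    """Download the dataset from the Hugging Face Hub (requires network access)."""
    # Imported lazily so that defining this helper does not require `datasets`.
    from datasets import load_dataset

    # Assumption: the repository works with the default loader and needs no
    # extra configuration name or split argument.
    return load_dataset(DATASET_ID)
```

Calling `load_tathagata()` and printing the result shows whichever splits and columns the repository actually provides.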