---
tags:
- image-to-text
- image-captioning
language:
- ru
metrics:
- bleu
library_name: transformers
---

# vit-rugpt2-image-captioning: the first image captioning model for the Russian language

This is an image captioning model trained on an English-to-Russian translation of the COCO 2014 dataset.

# Model Details

The model was initialized with `google/vit-base-patch16-224-in21k` as the encoder and `sberbank-ai/rugpt3large_based_on_gpt2` as the decoder.
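The card names only the two checkpoints, not the wiring code. As a rough, offline illustration of how `transformers` combines a ViT encoder with a GPT-2-style decoder, the sketch below builds a tiny, randomly initialized `VisionEncoderDecoderModel` from configs. All sizes and the `decoder_start_token_id`/`pad_token_id` values are hypothetical, chosen only so the snippet runs quickly; this is not the author's training code.

```python
import torch
from transformers import (
    GPT2Config,
    ViTConfig,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

# Tiny stand-ins for the real checkpoints (hypothetical sizes, random weights)
encoder_config = ViTConfig(
    image_size=32, patch_size=8, hidden_size=32,
    num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
)
decoder_config = GPT2Config(
    vocab_size=100, n_positions=64, n_embd=48, n_layer=2, n_head=4,
)

# Marks the GPT-2 config as a decoder with cross-attention over the ViT outputs
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
# Needed to build decoder inputs by shifting the labels right (hypothetical ids)
config.decoder_start_token_id = 0
config.pad_token_id = 1

model = VisionEncoderDecoderModel(config=config)

# One training-style forward pass: images in, caption token ids as labels
pixel_values = torch.randn(2, 3, 32, 32)
labels = torch.randint(2, 100, (2, 5))
outputs = model(pixel_values=pixel_values, labels=labels)

print(outputs.logits.shape)  # (batch, caption length, vocab size)
print(outputs.loss)          # cross-entropy over the caption tokens
```

Because the toy encoder (32) and decoder (48) hidden sizes differ, the wrapper inserts a linear `enc_to_dec_proj` layer between them. With the real checkpoints, the analogous starting point would presumably be `VisionEncoderDecoderModel.from_encoder_decoder_pretrained(...)` with the two model names above, followed by fine-tuning on the captions.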

# Metrics on the test set

* BLEU: 8.672
* BLEU 1-gram precision: 30.567
* BLEU 2-gram precision: 7.895
* BLEU 3-gram precision: 3.261
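The card does not say which BLEU implementation produced these numbers. For orientation, the pure-Python sketch below shows how corpus-level BLEU is usually assembled from clipped ("modified") n-gram precisions, here up to 4-grams with uniform weights, and a brevity penalty. Whitespace tokenization and the helper names `ngrams` and `corpus_bleu` are assumptions for illustration; library implementations add their own tokenization and optional smoothing, so scores will not match exactly.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams of a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses, references, max_n=4):
    """hypotheses: list of token lists; references: list of lists of token lists.

    Returns (bleu, precisions), both as percentages, without smoothing.
    """
    matches = [0] * max_n
    totals = [0] * max_n
    hyp_len = ref_len = 0
    for hyp, refs in zip(hypotheses, references):
        hyp_len += len(hyp)
        # Closest reference length (ties go to the shorter reference)
        ref_len += min((abs(len(r) - len(hyp)), len(r)) for r in refs)[1]
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            max_ref_counts = Counter()
            for r in refs:
                max_ref_counts |= ngrams(r, n)  # element-wise max over references
            # Clip each hypothesis n-gram count by its maximum reference count
            matches[n - 1] += sum(min(c, max_ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += max(len(hyp) - n + 1, 0)
    precisions = [100.0 * m / t if t else 0.0 for m, t in zip(matches, totals)]
    if min(precisions) == 0.0:
        return 0.0, precisions  # geometric mean is zero without smoothing
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    log_avg = sum(math.log(p / 100.0) for p in precisions) / max_n
    return 100.0 * brevity_penalty * math.exp(log_avg), precisions


hyp = "самолет на взлетно-посадочной полосе".split()
refs = ["самолет стоит на взлетно-посадочной полосе аэропорта".split()]
score, precisions = corpus_bleu([hyp], [refs])
print(round(score, 2), [round(p, 1) for p in precisions])  # 0.0 [100.0, 66.7, 50.0, 0.0]
```

The example shows why short captions are hard on this metric: a single missing 4-gram match drives unsmoothed BLEU to zero even when the lower-order precisions are high.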

# Example usage

```python

from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

model = VisionEncoderDecoderModel.from_pretrained("vit-rugpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("vit-rugpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("vit-rugpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Beam search settings passed to model.generate()
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_caption(image_paths):
    """Return one Russian caption per image path."""
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        # The ViT encoder expects 3-channel input, so convert grayscale, RGBA, etc.
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    # Resize and normalize the images into a single batch tensor
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Generate caption token ids with beam search
    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds

predict_caption(['train2014/COCO_train2014_000000295442.jpg']) # ['Самолет на взлетно-посадочной полосе аэропорта.'] ("An airplane on an airport runway.")

```

# Example usage with the transformers pipeline

```python

from transformers import pipeline

image_to_text = pipeline("image-to-text", model="vit-rugpt2-image-captioning")

image_to_text("train2014/COCO_train2014_000000296754.jpg") # [{'generated_text': 'Человек идет по улице с зонтом.'}] ("A person is walking down the street with an umbrella.")

```


# Contact for help
* https://huggingface.co/tuman
* https://github.com/tumanov-a
* https://t.me/tumanov_av