First image captioning model for the Russian language: vit-rugpt2-image-captioning

This is an image captioning model trained on a Russian translation (en-ru) of the COCO 2014 dataset.

Model Details

The model was initialized from google/vit-base-patch16-224-in21k for the encoder and from sberbank-ai/rugpt3large_based_on_gpt2 for the decoder.
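The encoder-decoder layout can be illustrated with a small sketch using the `transformers` classes `VisionEncoderDecoderConfig` and `VisionEncoderDecoderModel`. The card does not say exactly how the two checkpoints were combined. The commented-out `from_encoder_decoder_pretrained` call is an assumption about what warm-starting from the named checkpoints would look like, and it is not run here because it downloads both full models. The runnable part builds tiny, randomly initialized ViT and GPT-2 stand-ins. Their sizes, the token IDs, and the toy vocabulary are all hypothetical. The stand-ins show how the GPT-2 decoder is given cross-attention over the ViT patch embeddings.

```python
import torch
from transformers import (
    GPT2Config,
    ViTConfig,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

# Assumed warm-start from the real checkpoints (not run: large downloads):
# model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
#     "google/vit-base-patch16-224-in21k", "sberbank-ai/rugpt3large_based_on_gpt2"
# )

# Tiny, randomly initialized stand-ins with the same model types (hypothetical sizes).
encoder_config = ViTConfig(
    image_size=32, patch_size=8, hidden_size=16,
    num_hidden_layers=1, num_attention_heads=2, intermediate_size=32,
)
decoder_config = GPT2Config(
    vocab_size=100, n_positions=32, n_embd=32, n_layer=1, n_head=2,
    bos_token_id=0, eos_token_id=1,
)

# Marks the GPT-2 config as a decoder and enables cross-attention to the ViT outputs.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
config.decoder_start_token_id = 0  # hypothetical special-token IDs for the toy vocabulary
config.pad_token_id = 1
model = VisionEncoderDecoderModel(config=config)

# The encoder width (16) differs from the decoder width (32), so a linear
# projection is inserted between them.
print(model.enc_to_dec_proj)

# One training-style forward pass: an image goes in, caption token IDs serve as labels.
pixel_values = torch.randn(1, 3, 32, 32)
labels = torch.tensor([[5, 6, 7, 1]])
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.logits.shape)  # torch.Size([1, 4, 100])
print(outputs.loss)
```

When `labels` are given, the model shifts them right and prepends `decoder_start_token_id` to build the decoder inputs. That is why the sketch sets both special-token IDs on the config.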

Metrics on the test set

  • BLEU: 8.672
  • BLEU 1-gram precision: 30.567
  • BLEU 2-gram precision: 7.895
  • BLEU 3-gram precision: 3.261
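The card does not say which BLEU implementation or tokenization produced these numbers. The sketch below is a plain-Python version of standard corpus-level BLEU, assumed here for illustration. It computes clipped n-gram precisions, takes their geometric mean, and multiplies by a brevity penalty, without smoothing. It accepts several references per caption, as COCO provides. If the reported score is the usual 4-gram BLEU, it also depends on a 4-gram precision that the card does not list. The function names `ngrams` and `corpus_bleu` are hypothetical.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count all contiguous n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses, references, max_n=4):
    """hypotheses: list of token lists; references: list of lists of token lists.

    Returns (bleu, precisions) on a 0-1 scale, without smoothing.
    """
    matches = [0] * max_n
    totals = [0] * max_n
    hyp_len = ref_len = 0
    for hyp, refs in zip(hypotheses, references):
        hyp_len += len(hyp)
        # Use the reference length closest to the hypothesis (ties go to the shorter one).
        ref_len += min((abs(len(r) - len(hyp)), len(r)) for r in refs)[1]
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            max_ref = Counter()
            for r in refs:
                max_ref |= ngrams(r, n)  # element-wise max over references
            # Clip each n-gram count by its maximum count in any single reference.
            matches[n - 1] += sum(min(c, max_ref[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())
    precisions = [m / t if t else 0.0 for m, t in zip(matches, totals)]
    if min(precisions) == 0:
        return 0.0, precisions
    log_avg = sum(math.log(p) for p in precisions) / max_n
    # Penalize hypotheses that are shorter overall than the references.
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return bp * math.exp(log_avg), precisions


hyp = "a cat sits on the mat".split()
bleu, precisions = corpus_bleu([hyp], [[hyp]])
print(bleu, precisions)  # 1.0 [1.0, 1.0, 1.0, 1.0]
```

Scores in the list above appear to be on a 0-100 scale, so multiply these values by 100 to compare them.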

Sample code


from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image

# Model ID on the Hugging Face Hub
model_id = "tuman/vit-rugpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(model_id)
image_processor = ViTImageProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Run on the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generation settings: beam search with 4 beams, output capped at 16 tokens
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_caption(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        # The ViT encoder expects 3-channel RGB input
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    # Resize and normalize the images into a batch of pixel tensors
    pixel_values = image_processor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    # Decode token IDs back into text, dropping special tokens
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds

predict_caption(["train2014/COCO_train2014_000000295442.jpg"])  # ['Самолет на взлетно-посадочной полосе аэропорта.'] ("An airplane on the airport runway.")
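The preprocessing step in `predict_caption` converts each image to RGB and then resizes and normalizes it into a pixel tensor. The sketch below runs that step on synthetic images so no model download is needed. It assumes the repository's preprocessor config matches the ViT defaults (224x224 resize, mean and standard deviation of 0.5 per channel). The real values are read from the repository by `from_pretrained`. The blank grayscale and RGBA images are made up for illustration.

```python
from PIL import Image
from transformers import ViTImageProcessor

# Assumption: default ViT preprocessing (224x224, mean = std = 0.5 per channel).
processor = ViTImageProcessor()

# Non-RGB inputs, converted to RGB the same way predict_caption does it.
white_gray = Image.new("L", (320, 240), color=255).convert("RGB")
black_rgba = Image.new("RGBA", (100, 50), color=(0, 0, 0, 255)).convert("RGB")

pixel_values = processor(images=[white_gray, black_rgba], return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([2, 3, 224, 224])

# Pixels are rescaled to [0, 1] and then normalized with (x - 0.5) / 0.5,
# so white maps to about 1.0 and black to about -1.0.
print(pixel_values[0].mean().item(), pixel_values[1].mean().item())
```

Both inputs end up with the same fixed shape regardless of their original size, which is what allows `predict_caption` to batch images of different dimensions.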

Sample code using the transformers pipeline


from transformers import pipeline

image_to_text = pipeline("image-to-text", model="tuman/vit-rugpt2-image-captioning")

image_to_text("train2014/COCO_train2014_000000296754.jpg")  # [{'generated_text': 'Человек идет по улице с зонтом.'}] ("A person walks down the street with an umbrella.")

Contact for any help


Spaces using tuman/vit-rugpt2-image-captioning 3