---
license: apache-2.0
datasets:
- atasoglu/flickr8k-dataset
language:
- en
metrics:
- rouge
pipeline_tag: image-to-text
tags:
- image
- vision
---

A Vision Encoder Decoder model (ViT image encoder + BERT text decoder) fine-tuned on the [flickr8k-dataset](https://huggingface.co/datasets/atasoglu/flickr8k-dataset) for the image-to-text (image captioning) task.

Example:

```py
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, BertTokenizer
import torch
from PIL import Image

# load the image processor, tokenizer and model
image_processor = ViTImageProcessor.from_pretrained("atasoglu/vit-bert-flickr8k")
tokenizer = BertTokenizer.from_pretrained("atasoglu/vit-bert-flickr8k")
model = VisionEncoderDecoderModel.from_pretrained("atasoglu/vit-bert-flickr8k")

# run on the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# load the image; convert to RGB so grayscale or RGBA files are handled too
img = Image.open("example.jpg").convert("RGB")

# preprocess: resize and normalize the image into a pixel tensor
pixel_values = image_processor(images=[img], return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device)

# generate caption token ids (no gradients are needed for inference)
with torch.no_grad():
    output_ids = model.generate(pixel_values)

# decode the token ids into text
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(preds)
```
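The model pairs a ViT image encoder with a BERT text decoder through `VisionEncoderDecoderModel`. The following is a minimal sketch of that wiring, built from tiny, randomly initialised configs so it runs offline. It is an illustration under stated assumptions, not the training setup of this checkpoint: the layer sizes, the toy vocabulary size, and the special-token ids (`1` as decoder start, `2` as end, `0` as padding) are all hypothetical, and because the weights are random the generated ids carry no meaning.

```python
import torch
from transformers import (
    BertConfig,
    ViTConfig,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

torch.manual_seed(0)

# Hypothetical tiny sizes so the sketch runs quickly; these are NOT the
# dimensions of atasoglu/vit-bert-flickr8k.
encoder_config = ViTConfig(
    image_size=32, patch_size=8, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=128,
)
decoder_config = BertConfig(
    vocab_size=100, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=128,
)

# Combines both configs and marks BERT as a decoder with cross-attention,
# so it can attend to the ViT patch embeddings while generating text.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config
)

# Special-token ids needed by generate(); hypothetical values for the toy vocab.
config.decoder_start_token_id = 1
config.eos_token_id = 2
config.pad_token_id = 0

model = VisionEncoderDecoderModel(config=config)
model.eval()

# A batch of two random 32x32 RGB "images".
pixel_values = torch.randn(2, 3, 32, 32)

# Greedy decoding of at most 10 tokens per image.
with torch.no_grad():
    output_ids = model.generate(pixel_values, max_length=10)

print(output_ids.shape)
```

To build a real captioner, the encoder and decoder would typically be initialised from pretrained checkpoints with `VisionEncoderDecoderModel.from_encoder_decoder_pretrained(...)` and then fine-tuned; which checkpoints this model started from is not stated in this card.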