adalbertojunior commited on
Commit
f8bd957
·
1 Parent(s): 587e50d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
3
+ import requests
4
+ from PIL import Image
5
+ import torch
6
+
7
+
8
+ CHECKPOINT = "adalbertojunior/image_captioning_portuguese"
9
+
10
+ @st.cache
11
+ def get_model():
12
+ model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
13
+ return model
14
+
15
+
16
+ feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
17
+ tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
18
+
19
+ st.title("Image Captioning with ViT & GPT2 🇧🇷")
20
+
21
+ st.sidebar.markdown("## Generation parameters")
22
+ max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
23
+ no_repeat_ngram_size = st.sidebar.number_input("no repeat ngrams size", value=2, min_value=1)
24
+ num_return_sequences = st.sidebar.number_input("Generated sequences", value=3, min_value=1)
25
+
26
+ gen_mode = st.sidebar.selectbox("Generation mode", ["beam search", "sampling"])
27
+ if gen_mode == "beam search":
28
+ num_beams = st.sidebar.number_input("Beam size", value=5, min_value=1)
29
+ early_stopping = st.sidebar.checkbox("Early stopping", value=True)
30
+ gen_params = {
31
+ "num_beams": num_beams,
32
+ "early_stopping": early_stopping
33
+ }
34
+ elif gen_mode == "sampling":
35
+ do_sample = True
36
+ top_k = st.sidebar.number_input("top_k", value=30, min_value=0)
37
+ top_p = st.sidebar.number_input("top_p", value=0, min_value=0)
38
+ temperature = st.sidebar.number_input("temperature", value=0.7, min_value=0.0)
39
+ gen_params = {
40
+ "do_sample": do_sample,
41
+ "top_k": top_k,
42
+ "top_p": top_p,
43
+ "temperature": temperature
44
+ }
45
+
46
+ def generate_caption(url):
47
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
48
+ inputs = feature_extractor(image, return_tensors="pt")
49
+ model = get_model()
50
+ model.eval()
51
+ generated_ids = model.generate(
52
+ inputs["pixel_values"],
53
+ max_length=20,
54
+ no_repeat_ngram_size=2,
55
+ num_return_sequences=3,
56
+ **gen_params
57
+ )
58
+ captions = tokenizer.batch_decode(
59
+ generated_ids,
60
+ skip_special_tokens=True,
61
+ )
62
+ return captions[0]
63
+
64
+
65
+ url = st.text_input(
66
+ "Insert your URL", "https://iheartcats.com/wp-content/uploads/2015/08/c84.jpg"
67
+ )
68
+
69
+ st.image(url)
70
+
71
+ if st.button("Run captioning"):
72
+ with st.spinner("Processing image..."):
73
+ caption = generate_caption(url)
74
+ st.text(caption)