davanstrien HF staff commited on
Commit
a8221b0
·
1 Parent(s): 3bb51ac

draft pipeline

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
4
+ import requests
5
+ import asyncio
6
+ import httpx
7
+ import time
8
+ import io
9
+ from PIL import Image
10
+ import PIL
11
+
12
+ HF_MODEL_PATH = (
13
+ "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
14
+ )
15
+
16
+ classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
17
+ feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)
18
+
19
+ classif_pipeline = pipeline(
20
+ "image-classification", model=classif_model, feature_extractor=feature_extractor
21
+ )
22
+
23
+ OUTPUT_SENTENCE = "This image is {result}."
24
+
25
+
26
+ def load_manifest(inputs):
27
+ with requests.get(inputs) as r:
28
+ return r.json()
29
+
30
+
31
+ def get_image_urls_from_manifest(data):
32
+ image_urls = []
33
+ for sequences in data['sequences']:
34
+ for canvases in sequences['canvases']:
35
+ image_urls.extend(image['resource']['@id'] for image in canvases['images'])
36
+ return image_urls
37
+
38
+
39
+ def resize_iiif_urls(im_url, size='224'):
40
+ parts = im_url.split("/")
41
+ parts[6] = size, size
42
+ return "/".join(parts)
43
+
44
+
45
+ async def get_image(client, url):
46
+ try:
47
+ resp = await client.get(url, timeout=30)
48
+ return Image.open(io.BytesIO(resp.content))
49
+ except (PIL.UnidentifiedImageError, httpx.ReadTimeout):
50
+ return None
51
+
52
+
53
+ async def get_images(urls):
54
+ async with httpx.AsyncClient() as client:
55
+
56
+ tasks = [asyncio.ensure_future(get_image(client, url)) for url in urls]
57
+ images = await asyncio.gather(*tasks)
58
+ return [image for image in images if image is not None]
59
+
60
+
61
+ def predict(inputs):
62
+ data = load_manifest(inputs)
63
+ urls = get_image_urls_from_manifest(data)
64
+ resized_urls = [resize_iiif_urls(url) for url in urls]
65
+ images = asyncio.run(get_images(resized_urls))
66
+ predicted_images = []
67
+ for image in images:
68
+ top_pred = classif_pipeline(image, top_k=1)[0]
69
+ if top_pred['label'] == 'illustrated':
70
+ predicted_images.append((image, top_pred['score']))
71
+ if len(predicted_images):
72
+ return predicted_images
73
+
74
+
75
+ demo = gr.Interface(
76
+ fn=predict,
77
+ inputs=gr.Text(),
78
+ outputs=gr.Gallery(),
79
+ title="ImageIN",
80
+ description="Identify illustrations in pages of historical books!",
81
+ )
82
+ demo.launch(debug=True, share=True)