# logo_generator / dev / data / CC3M_downloader.py
# boris's picture
# chore: move files around
# 31da1e5
# raw
# history blame
# 2.3 kB
'''
This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
A few changes were made for this particular dataset. You need to have the `.tsv` file downloaded locally before running it.
The `.tsv` files are available here: https://github.com/google-research-datasets/conceptual-captions
'''
import sys
import os
from datetime import datetime
import pandas as pd
import contexttimer
from urllib.request import urlopen
import requests
from PIL import Image
import torch
from torchvision.transforms import functional as TF
from multiprocessing import Pool
from tqdm import tqdm
import logging
import sys
# Setup
# Failures are written to download.log; filemode='w' starts a fresh log on every run.
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
# Images are fetched with verify=False, so silence the InsecureRequestWarning
# that would otherwise be emitted for those requests.
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)

# Exactly two command-line arguments are expected: the .tsv path and the
# name of the output directory.
if len(sys.argv) != 3:
    print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
    # sys.exit is always available; the bare `exit` helper is only injected
    # by the `site` module and is meant for interactive sessions.
    sys.exit(1)

# Load data
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
with contexttimer.Timer(prefix="Loading from tsv"):
    # The file is read without a header row, so columns are labelled 0 and 1.
    df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)

# Each row unpacks as (index, caption, url). Keying by url means a url that
# appears more than once keeps only its last row index.
url_to_idx_map = {url: index for index, _caption, url in df.itertuples()}
print(f'Loaded {len(url_to_idx_map)} urls')

# Images are written below the current working directory.
base_dir = os.path.join(os.getcwd(), sys.argv[2])
# Create the output directory up front: without it every image.save() call
# would fail with FileNotFoundError and only show up in the log.
os.makedirs(base_dir, exist_ok=True)
def process(item):
    """Download one image, downscale it if it is large, and save it as a JPEG.

    The output file is ``<image_id zero-padded to 8 digits>---<url stem>.jpg``
    inside ``base_dir``. If that file already exists, nothing is downloaded.

    Args:
        item: A ``(url, image_id)`` pair. ``image_id`` must be an integer,
            because it is formatted with ``:08d``.

    Returns:
        None. Any exception is logged (the exception repr at INFO level, the
        url at ERROR level) and swallowed, so one bad url does not stop the
        rest of the download.
    """
    url, image_id = item
    try:
        base_url = os.path.basename(url)  # last path component of the url
        stem, _ext = os.path.splitext(base_url)  # the original extension is discarded
        filename = f'{image_id:08d}---{stem}.jpg'
        filepath = os.path.join(base_dir, filename)
        if os.path.isfile(filepath):
            return  # already downloaded on a previous run

        # The context manager releases the connection even when decoding
        # fails; otherwise a long-running worker leaks sockets.
        with requests.get(url, stream=True, timeout=1, verify=False) as response:
            # Fail early on HTTP errors instead of handing an error body to PIL.
            response.raise_for_status()
            # `.raw` bypasses requests' content decoding; ask urllib3 to undo
            # gzip/deflate Content-Encoding so PIL receives real image bytes.
            response.raw.decode_content = True
            # convert() forces the pixel data to be loaded while the
            # connection is still open.
            image = Image.open(response.raw).convert('RGB')

        # Only shrink images whose shorter side exceeds 512 pixels.
        if min(image.size) > 512:
            image = TF.resize(image, size=512, interpolation=Image.LANCZOS)

        # Write to a short per-image temporary name and rename afterwards, so
        # an interrupted save never leaves a truncated file under the final
        # name (which the isfile() check above would treat as complete).
        tmp_path = os.path.join(base_dir, f'{image_id:08d}.part')
        image.save(tmp_path, format='JPEG')
        os.replace(tmp_path, filepath)
    except Exception as e:
        # Best-effort download: record the failure and move on.
        logging.info(" ".join(repr(e).splitlines()))
        logging.error(url)
# Start the pool only when the file is run as a script. With the "spawn"
# start method each worker re-imports this module, and without the guard it
# would try to create its own Pool and fail with a RuntimeError.
# NOTE(review): under "spawn" the workers still repeat the top-level loading
# above when they import the module.
if __name__ == '__main__':
    list_of_items = list(url_to_idx_map.items())  # (url, image_id) pairs
    print(len(list_of_items))
    with Pool(128) as p:
        # process() returns nothing useful, so just drain the iterator to
        # drive the progress bar instead of collecting the results.
        for _ in tqdm(p.imap(process, list_of_items), total=len(list_of_items)):
            pass
    print('DONE')