from PIL import Image
import pandas as pd
import torch
from torch.cuda.amp import autocast
import clip
import gradio as gr
from dataclasses import dataclass

from dataloader.extract_features_dataloader import transform_resize, question_preprocess
from model.vqa_model import NetVQA


@dataclass
class InferenceConfig:
    '''
    Describes the configuration of the inference process
    '''
    model: str = "RN50x64"
    checkpoint_root_clip: str = "./checkpoints/clip"
    checkpoint_root_head: str = "./checkpoints/head"

    use_question_preprocess: bool = True  # True: delete "?" at the end of the question

    aux_mapping = {0: "unanswerable",
                   1: "unsuitable",
                   2: "yes",
                   3: "no",
                   4: "number",
                   5: "color",
                   6: "other"}
    folds = 10

    tta = False

    # Data
    n_classes: int = 5726
    # class mapping
    class_mapping: str = "./data/annotations/class_mapping.csv"

    device = "cuda" if torch.cuda.is_available() else "cpu"


config = InferenceConfig()

# Load the mapping from class id to answer string
cm = pd.read_csv(config.class_mapping)
classid_to_answer = {}
for i in range(len(cm)):
    row = cm.iloc[i]
    classid_to_answer[row["class_id"]] = row["answer"]

# Load the CLIP backbone and the VQA classification head
clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip)

model = NetVQA(config).to(config.device)

config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model)
model_state_dict = torch.load(config.checkpoint_head, map_location=config.device)
model.load_state_dict(model_state_dict, strict=True)

#%%
# Select preprocessing
image_transforms = transform_resize(clip_model.visual.input_resolution)

if config.use_question_preprocess:
    question_transforms = question_preprocess
else:
    # No preprocessing: pass the question through unchanged so predict() can
    # always call question_transforms
    def question_transforms(q):
        return q

clip_model.eval()
model.eval()


def predict(img, text):
    img = Image.fromarray(img)
    if config.tta:
        # Test-time augmentation: stack all augmented views of the image.
        # This branch expects image_transforms to be a list of transforms and
        # config.features_selection to hold the per-view weights.
        image_augmentations = []
        for transform in image_transforms:
            image_augmentations.append(transform(img))
        img = torch.stack(image_augmentations, dim=0)
    else:
        img = image_transforms(img)
        img = img.unsqueeze(dim=0)

    question = question_transforms(text)
    question_tokens = clip.tokenize(question, truncate=True)

    with torch.no_grad():
        # Encode the image with CLIP
        img = img.to(config.device)
        img_feature = clip_model.encode_image(img)
        if config.tta:
            # Weighted sum of the augmented image features
            weights = torch.tensor(config.features_selection).reshape((len(config.features_selection), 1))
            img_feature = img_feature * weights.to(config.device)
            img_feature = img_feature.sum(0)
            img_feature = img_feature.unsqueeze(0)

        # Encode the question with CLIP
        question_tokens = question_tokens.to(config.device)
        question_feature = clip_model.encode_text(question_tokens)

        with autocast():
            output, output_aux = model(img_feature, question_feature)

    # Map answer logits to answer strings
    prediction_vqa = dict()
    output = output.cpu().squeeze(0)
    for k, v in classid_to_answer.items():
        prediction_vqa[v] = float(output[k])

    # Map auxiliary logits to answer categories
    prediction_aux = dict()
    output_aux = output_aux.cpu().squeeze(0)
    for k, v in config.aux_mapping.items():
        prediction_aux[v] = float(output_aux[k])

    return prediction_vqa, prediction_aux


gr.Interface(
    fn=predict,
    inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
    outputs=[gr.Label(label='Answer', num_top_classes=5),
             gr.Label(label='Answer Category', num_top_classes=7)],
    examples=[
        ['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'],
        ['examples/VizWiz_train_00017146.jpg', 'Can you tell me what\'s on this envelope please?'],
        ['examples/VizWiz_val_00003077.jpg', 'What is this?'],
    ]
).launch()
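

# ---------------------------------------------------------------------------
# Optional local smoke test: a minimal sketch of calling predict() directly,
# without the Gradio UI. It assumes the checkpoints above are downloaded and
# that one of the bundled example images exists at the path used below; the
# RUN_SMOKE_TEST flag is introduced here only for illustration. Because
# .launch() above blocks, this runs only after the app is closed (or if the
# launch call is commented out).
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    import numpy as np

    # predict() expects the image as a numpy array, matching gr.Image's output
    example_image = np.array(Image.open("examples/VizWiz_val_00003077.jpg").convert("RGB"))
    answers, categories = predict(example_image, "What is this?")

    # Print the five highest-scoring answers and all answer categories
    print(sorted(answers.items(), key=lambda kv: kv[1], reverse=True)[:5])
    print(sorted(categories.items(), key=lambda kv: kv[1], reverse=True))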