"""Caption a single image with a trained caption model and a BERT tokenizer."""

import numpy as np
import torch
import torchvision
from PIL import Image
from transformers import BertTokenizer

import caption_model
from configuration import Config


def under_max(image, max_dim=299):
    """Return an RGB version of ``image`` whose longer side is ``max_dim`` pixels.

    Both sides are multiplied by the same factor, so the aspect ratio is kept
    up to integer truncation. Images smaller than ``max_dim`` are scaled up.

    Args:
        image: A ``PIL.Image.Image``.
        max_dim: Target length, in pixels, of the longer side.

    Returns:
        A new, resized ``PIL.Image.Image`` in RGB mode.
    """
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # The builtin ``float`` is used as dtype: the ``np.float`` alias no longer
    # exists in NumPy >= 1.24.
    size = np.array(image.size, dtype=float)
    scale = max_dim / size.max()

    # Truncate towards zero, but never below 1 px (a very elongated image would
    # otherwise get a zero-length side), and hand PIL a tuple of plain ints.
    new_size = tuple(max(1, int(side)) for side in size * scale)
    return image.resize(new_size)


class Model(object):
    """Wraps the caption model, its tokenizer and the image preprocessing."""

    def __init__(self, gpu=0, checkpoint_path='./checkpoint.pth'):
        """Build the model, load its weights and prepare decoding templates.

        Args:
            gpu: CUDA device index. Ignored when CUDA is not available, in
                which case the model runs on the CPU.
            checkpoint_path: Path of the checkpoint file; its ``'model'`` entry
                is loaded as the model's state dict.
        """
        config = Config()
        # Fall back to the CPU instead of failing on hosts without CUDA.
        if torch.cuda.is_available():
            config.device = 'cuda:{}'.format(gpu)
        else:
            config.device = 'cpu'

        model, _ = caption_model.build_model(config)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        model.to(config.device)
        model.eval()

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Public id accessors instead of the tokenizer's private attributes.
        self.start_token = tokenizer.cls_token_id
        self.end_token = tokenizer.sep_token_id

        # Pristine decoding templates: only position 0 (the start token) is
        # filled in and unmasked. ``evaluate`` works on clones of these, so
        # they are never modified after construction.
        template_shape = (1, config.max_position_embeddings)
        self.caption = torch.zeros(
            template_shape, dtype=torch.long, device=config.device)
        self.cap_mask = torch.ones(
            template_shape, dtype=torch.bool, device=config.device)
        self.caption[:, 0] = self.start_token
        self.cap_mask[:, 0] = False

        self.val_transform = torchvision.transforms.Compose([
            torchvision.transforms.Lambda(under_max),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.model = model
        self.config = config
        self.tokenizer = tokenizer

    def evaluate(self, im):
        """Greedily decode caption token ids for one preprocessed image.

        At each step the highest-scoring token is appended; decoding stops
        when the end token is predicted or the caption buffer is full.

        Args:
            im: Image tensor passed to the model as its first argument
                (``predict`` supplies the transformed image with a leading
                batch dimension of 1).

        Returns:
            A ``(1, max_position_embeddings)`` long tensor of token ids. It
            starts with the start token; positions that were never filled
            remain 0.
        """
        self.model.eval()
        im = im.to(self.config.device)

        # Start from fresh copies so consecutive calls do not see each
        # other's tokens.
        caption = self.caption.clone()
        cap_mask = self.cap_mask.clone()

        with torch.no_grad():
            for i in range(self.config.max_position_embeddings - 1):
                predictions = self.model(im, caption, cap_mask)
                # Scores at position i are used to pick the token for i + 1.
                token = torch.argmax(predictions[:, i, :], dim=-1)[0].item()
                if token == self.end_token:
                    break
                caption[:, i + 1] = token
                cap_mask[:, i + 1] = False

        return caption

    def predict(self, image_path):
        """Return the decoded caption string for the image at ``image_path``.

        Args:
            image_path: Path of an image file readable by PIL.

        Returns:
            The caption text, decoded with special tokens skipped.
        """
        # Transform inside the ``with`` so the file is closed once its pixels
        # have been read.
        with Image.open(image_path) as image:
            tensor = self.val_transform(image)

        output = self.evaluate(tensor.unsqueeze(0))
        return self.tokenizer.decode(output[0].tolist(), skip_special_tokens=True)


if __name__ == "__main__":
    model = Model()
    result = model.predict("./image.jpg")
    print(result)