# LAVT model definitions.
from collections import OrderedDict
import sys
import torch
from torch import nn
from torch.nn import functional as F
from bert.modeling_bert import BertModel
class _LAVTSimpleDecode(nn.Module):
def __init__(self, backbone, classifier):
super(_LAVTSimpleDecode, self).__init__()
self.backbone = backbone
self.classifier = classifier
def forward(self, x, l_feats, l_mask):
input_shape = x.shape[-2:]
features = self.backbone(x, l_feats, l_mask)
x_c1, x_c2, x_c3, x_c4 = features
x = self.classifier(x_c4, x_c3, x_c2, x_c1)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
return x
class LAVT(_LAVTSimpleDecode):
    """Public LAVT model (expects language features from an external
    text encoder). All behavior is inherited unchanged."""
###############################################
# LAVT One: put BERT inside the overall model #
###############################################
class _LAVTOneSimpleDecode(nn.Module):
def __init__(self, backbone, classifier, args):
super(_LAVTOneSimpleDecode, self).__init__()
self.backbone = backbone
self.classifier = classifier
self.text_encoder = BertModel.from_pretrained(args.ck_bert)
self.text_encoder.pooler = None
def forward(self, x, text, l_mask):
input_shape = x.shape[-2:]
### language inference ###
l_feats = self.text_encoder(text, attention_mask=l_mask)[0] # (6, 10, 768)
l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1)
##########################
features = self.backbone(x, l_feats, l_mask)
x_c1, x_c2, x_c3, x_c4 = features
x = self.classifier(x_c4, x_c3, x_c2, x_c1)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
return x
class LAVTOne(_LAVTOneSimpleDecode):
    """Public LAVT model with BERT built in (takes raw token ids).
    All behavior is inherited unchanged."""