import torch
import torch.nn as nn
import clip

from model import ST_GCN_18


class ContrastiveModule(nn.Module):
    """CLIP-style contrastive model pairing an ST-GCN IMU encoder with CLIP's text encoder."""

    def __init__(self, args):
        super(ContrastiveModule, self).__init__()

        # Load CLIP for its text tower; the ViT image tower is replaced by the
        # IMU encoder below, so it is deleted to save memory.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, _ = clip.load("ViT-B/32", device=device)
        del model.visual
        self.model = model

        # Accelerometer contributes 3 channels; adding gyro and/or STFT
        # features doubles the channel count each time.
        base_channel = 3
        base_channel = base_channel * 2 if args.gyro else base_channel
        base_channel = base_channel * 2 if args.stft else base_channel

        # ST-GCN backbone serves as the "image" (IMU) encoder.
        self.model.acc = ST_GCN_18(in_channels=base_channel)

        # clip.load() returns fp16 weights on CUDA; keep everything in fp32.
        self.model = self.model.float()

        # Linear classification head, only used in the fine-tuning stage.
        if args.stage == 'finetune':
            self.fc = nn.Linear(512, args.num_class)

    def encode_image(self, image):
        # "image" is the IMU tensor; squeeze away the trailing singleton
        # dimensions left by the ST-GCN pooling.
        return self.model.acc(image.float()).squeeze(-1).squeeze(-1)

    def encode_text(self, text):
        # Mirror CLIP's own text encoder, but force fp32 throughout.
        x = self.model.token_embedding(text).float()
        x = x + self.model.positional_embedding.float()
        x = x.permute(1, 0, 2)  # NLD -> LND for the transformer
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x).float()

        # Take the activation at the EOT token (highest token id in each
        # sequence) and project it into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection

        return x

    def classifier(self, image):
        # Fine-tuning path: encode the IMU sequence and map it to class logits.
        imu_features = self.model.acc(image.float()).squeeze(-1).squeeze(-1)
        out = self.fc(imu_features)
        return out

    def forward(self, inputs_imu, inputs_text):
        imu_features = self.encode_image(inputs_imu)
        text_features = self.encode_text(inputs_text)

        # L2-normalize both modalities so the dot products below are cosine similarities.
        imu_features = imu_features / imu_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Temperature-scaled pairwise similarity matrix, as in CLIP.
        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * imu_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
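

# --- Hedged usage sketch (not part of the original module) -------------------
# Minimal illustration of how ContrastiveModule might be constructed and called
# for contrastive pretraining. The args fields (gyro, stft, stage, num_class)
# come from the constructor above; the IMU tensor shape (N, C, T, V, M) follows
# the common ST-GCN convention and is an assumption -- adapt it to whatever
# ST_GCN_18 in model.py actually expects.
if __name__ == "__main__":
    from argparse import Namespace

    # Hypothetical args; the real project presumably builds this via argparse.
    args = Namespace(gyro=False, stft=False, stage="pretrain", num_class=10)
    module = ContrastiveModule(args)

    # Dummy batch: 2 samples, 3 accelerometer channels, 100 time steps,
    # 5 sensor nodes, 1 "person" slot (shape is an assumption, see above).
    dummy_imu = torch.randn(2, 3, 100, 5, 1)
    dummy_text = clip.tokenize(["walking", "running"])

    logits_per_imu, logits_per_text = module(dummy_imu, dummy_text)
    # Expect a (2, 2) similarity matrix, assuming ST_GCN_18 emits 512-d
    # features matching CLIP's text embedding width.
    print(logits_per_imu.shape, logits_per_text.shape)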