# model_soups / zeroshot.py
# Author: SaraPieri; commit 626ec32 ("First soup!"); 850 bytes.
import argparse
import os
import torch
import clip
import os
from tqdm import tqdm
def zeroshot_classifier(model, classnames, templates, device, logit_scale=100.0):
    """Build a CLIP-style zero-shot classification head from prompt templates.

    For every class name, each template is applied to produce a prompt string.
    The prompts are tokenized with ``clip.tokenize`` and embedded with
    ``model.encode_text``. Each prompt embedding is L2-normalized, the
    normalized embeddings are averaged, and the average is L2-normalized again.

    Args:
        model: Object exposing ``encode_text(tokens)``. Presumably a CLIP model
            already placed on ``device`` -- verify at call sites.
        classnames: Iterable of class names; each one is passed to every template.
        templates: Iterable of callables, each mapping a class name to a prompt
            string. It is materialized once, so a generator is reused correctly
            for every class.
        device: Device that the token tensor and the returned weights are moved to.
        logit_scale: Multiplier applied to the returned weights. The default of
            100 reproduces the previously hard-coded factor.

        The function prints one status line to stdout and shows a ``tqdm``
        progress bar over the classes.

    Returns:
        Tensor whose row ``i`` is ``logit_scale`` times the unit-norm ensembled
        text embedding of the ``i``-th class name. Its shape is
        ``(num_classes, embed_dim)``, assuming ``encode_text`` returns
        ``(num_prompts, embed_dim)`` -- TODO confirm for non-CLIP models.

    Raises:
        ValueError: If ``classnames`` or ``templates`` is empty.
    """
    # Materialize both inputs. ``templates`` is consumed once per class, so a
    # one-shot iterator would silently yield no prompts after the first class.
    classnames = list(classnames)
    templates = list(templates)
    if not classnames:
        raise ValueError('classnames must contain at least one class name')
    if not templates:
        # With no prompts there is nothing to average per class; at best the
        # mean would be NaN, so reject this up front.
        raise ValueError('templates must contain at least one template')

    print('Building zero-shot classifier.')
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            tokens = clip.tokenize(texts).to(device)
            class_embeddings = model.encode_text(tokens)
            # Normalize each prompt embedding before averaging, so every
            # template contributes equally regardless of its raw magnitude.
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # The mean of unit vectors is not unit-length; renormalize it.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # Stacking along dim 0 gives (num_classes, embed_dim) directly. These are
        # the same values the former stack(dim=1).t() produced, but contiguous.
        zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
    return logit_scale * zeroshot_weights