|
import argparse |
|
import os |
|
import torch |
|
import clip |
|
import os |
|
from tqdm import tqdm |
|
|
|
def zeroshot_classifier(model, classnames, templates, device):
    """Build zero-shot classification weights from text prompts.

    For every class name, each template is called with that name to produce
    a prompt. The prompts are tokenized with ``clip.tokenize``, encoded with
    ``model.encode_text``, normalized along the last dimension, averaged over
    the prompts, and the average is normalized again.

    Args:
        model: Object exposing ``encode_text`` for a batch of tokenized
            prompts.
        classnames: Iterable of class names; one output row is produced per
            entry, in iteration order.
        templates: Re-iterable collection of callables, each mapping a class
            name to a prompt string (it is traversed once per class).
        device: Device the tokenized prompts and the stacked weights are
            moved to.

    Returns:
        The per-class embeddings stacked as columns, transposed so that each
        row corresponds to one class, and multiplied by 100.
    """
    print('Building zero-shot classifier.')

    per_class = []
    with torch.no_grad():
        for name in tqdm(classnames):
            prompts = [make_prompt(name) for make_prompt in templates]
            tokens = clip.tokenize(prompts).to(device)

            embeddings = model.encode_text(tokens)
            # Normalize each prompt embedding before averaging.
            embeddings /= embeddings.norm(dim=-1, keepdim=True)

            # Average over prompts, then re-normalize the class prototype.
            prototype = embeddings.mean(dim=0)
            prototype /= prototype.norm()
            per_class.append(prototype)

        # Classes become columns here; the transpose below turns them into rows.
        stacked = torch.stack(per_class, dim=1).to(device)

    return 100 * stacked.t()
|
|
|
|
|
|