# [web-viewer residue, not part of the script] 1 · init · 75f07f8 · raw · history blame · 20.4 kB
# Copied from https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct/blob/main/scripts/eval_mteb.py
#### ATTENTION ####
# To reproduce the "Sparse" and "Dense + Sparse" results, you need to patch MTEB's RetrievalEvaluator
# in mteb/evaluation/evaluators/RetrievalEvaluator.py:
# class RetrievalEvaluator(Evaluator):
# def __init__(
# self,
# retriever=None,
# task_name: str | None = None,
# k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
# score_function: str = "cos_sim",
# encode_kwargs: dict[str, Any] = {},
# **kwargs,
# ):
# Change the default score_function from "cos_sim" to "dot" to reproduce the "Sparse" and "Dense + Sparse" results.
# Embedding mode selector; accepted values are "Dense", "Sparse" and "Dense + Sparse".
MODE = "Dense"  # "Dense" or "Sparse" or "Dense + Sparse"

# English MTEB tasks, grouped by task type.
TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]

TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]

TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "FiQA2018",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
    "NFCorpus",
    "NQ",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "HotpotQA",
    "MSMARCO",
    "FEVER",
]

TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]

MTEB_TASK_LIST = (
    TASK_LIST_RETRIEVAL
    + TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_STS
)

# Chinese (C-MTEB) tasks. Each name appears exactly once; the original list
# repeated "MultilingualSentiment", which scheduled that task twice.
CMTEB_TASK_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
    "AmazonReviewsClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
    "Ocnli",
    "Cmnli",
    "T2Reranking",
    "MMarcoReranking",
    "CMedQAv1-reranking",
    "CMedQAv2-reranking",
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
    "STS22",
]

# Full task list: C-MTEB first, then English MTEB. Several names occur in both
# benchmarks (e.g. "STS22"), so de-duplicate while keeping first-seen order;
# otherwise the same task would be evaluated more than once.
MTEB_TASK_LIST = list(dict.fromkeys(CMTEB_TASK_LIST + MTEB_TASK_LIST))
import torch
import torch.nn.functional as F
import tqdm
import numpy as np
import math
from functools import partial
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput
from typing import List, Dict
from mteb import MTEB
def get_detailed_instruct(task_description: str) -> str:
    """Wrap a task description into the query prefix used for encoding.

    Args:
        task_description: Instruction text; may be empty or ``None``.

    Returns:
        ``"Instruction: <task_description> Query: "`` when a description is
        given, otherwise the empty string.
    """
    if task_description:
        return f"Instruction: {task_description} Query: "
    return ""
def get_task_def_by_task_name_and_type(
    task_name: str,
    task_type: str,
    default_instruct="",
):
    """Look up the natural-language instruction for a benchmark task.

    Args:
        task_name: Task name, e.g. ``"NFCorpus"``.
        task_type: Task type, e.g. ``"Retrieval"`` or ``"Classification"``.
        default_instruct: Value returned when ``task_type`` is not one of the
            types handled below.

    Returns:
        The instruction string for the task; ``None`` for STS tasks and for
        task names without an entry in the table of a handled type;
        ``default_instruct`` for an unhandled task type.
    """
    if task_type == "STS":
        return None

    if task_type == "Summarization":
        return "Given a news summary, retrieve other semantically similar summaries"

    if task_type == "Classification":
        instructions = {
            "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual.",
            "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment.",
            "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category.",
            "Banking77Classification": "Given a online banking query, find the corresponding intents.",
            "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.",
            "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset.",
            "MassiveIntentClassification": "Given a user utterance as query, find the user intents.",
            "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios.",
            "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation.",
            "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation.",
            "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic.",
            "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral.",
            # C-MTEB eval instructions
            "TNews": "根据标题确定新闻的类别。",
            "IFlyTek": "根据描述确定APP的类别。",
            "MultilingualSentiment": "将亚马逊评论分为积极、消极或中立情绪。",
            "JDReview": "将商品评论分为积极或消极情绪。",
            "OnlineShopping": "将商品评论分为积极或消极情绪。",
            "Waimai": "将外卖评论分为积极或消极情绪。",
        }
    elif task_type == "Clustering":
        instructions = {
            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts.",
            "ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles.",
            "BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts.",
            "BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles.",
            "MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts.",
            "MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles.",
            "RedditClustering": "Identify the topic or theme of Reddit posts based on the titles.",
            "RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts.",
            "StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles.",
            "StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs.",
            "TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles.",
            # C-MTEB eval instructions
            "CLSClusteringS2S": "根据标题确定文章的类别。",
            "CLSClusteringP2P": "根据标题和摘要确定文章的类别。",
            "ThuNewsClusteringS2S": "根据标题确定新闻的类别。",
            "ThuNewsClusteringP2P": "根据标题和摘要确定新闻的类别。",
        }
    elif task_type in ("Reranking", "PairClassification"):
        instructions = {
            "AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum.",
            "MindSmallReranking": "Retrieve relevant news articles based on user browsing history.",
            "SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers.",
            "StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum.",
            "SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum.",
            "TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet.",
            "TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet.",
            # C-MTEB eval instructions
            "T2Reranking": "为这个问题检索相关段落。",
            "MMarcoReranking": "为这个查询检索相关段落。",
            "CMedQAv1-reranking": "为这个医疗问题检索相关回答。",
            "CMedQAv2-reranking": "为这个医疗问题检索相关回答。",
        }
    elif task_type == "Retrieval":
        # All CQADupstack sub-tasks share one instruction.
        if task_name.lower().startswith("cqadupstack"):
            return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question"
        instructions = {
            "ArguAna": "Given a claim, find documents that refute the claim.",
            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim.",
            "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia.",
            "FEVER": "Given a claim, retrieve documents that support or refute the claim.",
            "FiQA2018": "Given a financial question, retrieve user replies that best answer the question.",
            "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question.",
            "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query.",
            "NFCorpus": "Given a question, retrieve relevant documents that best answer the question.",
            "NQ": "Given a question, retrieve Wikipedia passages that answer the question.",
            "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question.",
            "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.",
            "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim.",
            "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question.",
            "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query.",
            # C-MTEB eval instructions
            "T2Retrieval": "为这个问题检索相关段落。",
            "MMarcoRetrieval": "为这个查询检索相关段落。",
            "DuRetrieval": "为这个问题检索相关百度知道回答。",
            "CovidRetrieval": "为这个问题检索相关政策回答。",
            "CmedqaRetrieval": "为这个医疗问题检索相关回答。",
            "EcomRetrieval": "为这个查询检索相关商品标题。",
            "MedicalRetrieval": "为这个医疗问题检索相关回答。",
            "VideoRetrieval": "为这个电影标题检索相关段落。",
        }
        # Retrieval names are also accepted in all-lowercase form.
        lowered = {name.lower(): text for name, text in instructions.items()}
        instructions = {**instructions, **lowered}
    else:
        return default_instruct

    return instructions.get(task_name)
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize the ``'input_texts'`` column of a batch of examples.

    Args:
        tokenizer: Callable tokenizer applied to the texts.
        examples: Batch mapping; only the ``'input_texts'`` entry is read.

    Returns:
        Whatever the tokenizer returns for the texts, with padding enabled and
        truncation to at most 1024 tokens requested.
    """
    texts = examples['input_texts']
    return tokenizer(texts, max_length=1024, padding=True, truncation=True)
# def weighted_mean_pooling(hidden,attention_mask):
# # print(hidden.shape,attention_mask.shape)
# attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
# s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
# d = attention_mask_.sum(dim=1, keepdim=True).float()
# reps = s / d
# return reps
def mean_pooling(hidden, attention_mask):
    """Average the hidden states over the unmasked positions of each row.

    Args:
        hidden: Tensor indexed as (batch, sequence, feature).
        attention_mask: Tensor indexed as (batch, sequence); non-zero entries
            mark positions that contribute to the average.

    Returns:
        Float tensor of shape (batch, feature). A row whose mask is entirely
        zero yields zeros instead of NaN.
    """
    mask = attention_mask.unsqueeze(-1).float()
    summed = torch.sum(hidden * mask, dim=1)
    counts = attention_mask.sum(dim=1, keepdim=True).float()
    # Guard against 0/0 for fully masked rows; a no-op whenever at least one
    # position is unmasked (count >= 1).
    return summed / torch.clamp(counts, min=1e-9)
def wmean_pooling(hidden, attention_mask):
    """Position-weighted average of the hidden states.

    Each unmasked position is weighted by the running count of unmasked
    positions up to and including it (``mask * cumsum(mask)``), so later
    tokens weigh more.

    Args:
        hidden: Tensor indexed as (batch, sequence, feature).
        attention_mask: Tensor indexed as (batch, sequence).

    Returns:
        Float tensor of shape (batch, feature). A row whose mask is entirely
        zero yields zeros instead of NaN.
    """
    position_weights = attention_mask * attention_mask.cumsum(dim=1)
    weighted_sum = torch.sum(hidden * position_weights.unsqueeze(-1).float(), dim=1)
    weight_total = position_weights.sum(dim=1, keepdim=True).float()
    # Guard against 0/0 for fully masked rows; a no-op otherwise (total >= 1).
    return weighted_sum / torch.clamp(weight_total, min=1e-9)
def reverse_wmean_pooling(hidden, attention_mask):
    """Rescale per-token hidden states by the inverse of their position weight.

    Position weights are ``mask * cumsum(mask)``. Every token is multiplied by
    (sum of position weights / number of unmasked tokens) for its row and then
    divided by its own position weight, clamped to at least 1e-9.
    NOTE(review): this appears to undo a position weighting applied to the
    hidden states upstream — confirm against the model code.

    Args:
        hidden: Tensor indexed as (batch, sequence, feature).
        attention_mask: Tensor indexed as (batch, sequence).

    Returns:
        Float tensor with the same shape as ``hidden``.
    """
    position_weights = attention_mask * attention_mask.cumsum(dim=1)
    weight_total = position_weights.sum(dim=1, keepdim=True).unsqueeze(1).float()
    token_total = attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
    rescaled = hidden.float() * (weight_total / token_total)
    divisor = torch.clamp(position_weights.unsqueeze(-1).float(), min=1e-9)
    return rescaled / divisor
def sparse_pooling(head, model, items, hidden, attention_mask):
    """Build a vocabulary-sized sparse embedding from per-token weights.

    Args:
        head: Callable applied to the normalised hidden states to produce the
            per-token weights (assumed to yield one weight per token — TODO confirm).
        model: Object exposing ``embed_tokens.weight``; its first dimension is
            used as the vocabulary size.
        items: Mapping containing ``"input_ids"``.
        hidden: Hidden states of the encoder.
        attention_mask: Attention mask matching ``hidden``.

    Returns:
        Tensor of shape (batch, vocab_size): for every row, the maximum weight
        observed for each token id, with the ids 0, 1, 2 and 73440 zeroed.
    """
    # Reverse the weighted mean pooling first, because the hidden states are
    # modified in the model.
    hidden = reverse_wmean_pooling(hidden, attention_mask)

    # Normalise every row by the largest token norm found in that row.
    per_token_norm = torch.norm(hidden, dim=-1)
    max_hidden_norm = torch.max(per_token_norm, dim=-1).values
    scaled_hidden = hidden.float() / max_hidden_norm.unsqueeze(-1).unsqueeze(-1)
    token_weights = torch.relu(head(scaled_hidden))

    vocab_size = model.embed_tokens.weight.size(0)
    input_ids = items["input_ids"]
    total_rows = token_weights.size(0)
    seq_len = input_ids.size(1)

    # Rows are processed one at a time so that only a single
    # (rows, seq_len, vocab_size) scratch tensor is alive at once.
    step = min(1, hidden.shape[0])
    pooled_chunks = []
    for start in range(0, total_rows, step):
        stop = start + min(step, total_rows - start)
        scratch = torch.zeros(stop - start, seq_len, vocab_size,
                              dtype=token_weights.dtype,
                              device=token_weights.device)
        scattered = torch.scatter(scratch,
                                  dim=-1,
                                  index=input_ids[start:stop, :].unsqueeze(-1),
                                  src=token_weights[start:stop, :])
        # Keep the largest weight each token id received along the sequence.
        pooled_chunks.append(torch.max(scattered, dim=1).values)

    sparse_embedding = torch.concat(pooled_chunks, dim=0)
    unused_tokens = [0, 1, 2, 73440]
    sparse_embedding[:, unused_tokens] *= 0.
    return sparse_embedding
def concat_pooling(head, model, items, hidden, attention_mask):
    """Concatenate the dense and the sparse representation of each input.

    The dense part is the L2-normalised result of ``mean_pooling``; the sparse
    part is the result of ``sparse_pooling`` scaled by ``sqrt(0.3)``.

    Returns:
        Tensor with the dense features followed by the sparse features along
        the last dimension.
    """
    dense_part = F.normalize(mean_pooling(hidden, attention_mask), p=2, dim=1)
    sparse_scale = math.sqrt(0.3)
    sparse_part = sparse_pooling(head, model, items, hidden, attention_mask) * sparse_scale
    return torch.cat([dense_part, sparse_part], dim=-1)
#
class DenseEncoder(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
model_path = "openbmb/UltraRAG-Embedding"
self.encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True,attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.gpu_count = torch.cuda.device_count()
self.instruction = ""
self.encoder.eval()
self.encoder.cuda()
if self.gpu_count > 1:
self.encoder = torch.nn.DataParallel(self.encoder)
@torch.no_grad()
def encode(self, sentences,is_query=None, **kwargs) -> np.ndarray:
""" Returns a list of embeddings for the given sentences.
Args:
sentences (`List[str]`): List of sentences to encode
batch_size (`int`): Batch size for the encoding
Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
"""
if is_query is not False:
sentences = [self.instruction + s for s in sentences]
dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
# dataset: Dataset = Dataset.from_dict({'input_texts': ["Query: " + s for s in sentences]})
dataset.set_transform(partial(_transform_func, self.tokenizer))
data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
data_loader = DataLoader(
dataset,
batch_size=128* self.gpu_count,
shuffle=False,
drop_last=False,
num_workers=2,
collate_fn=data_collator,
pin_memory=True)
encoded_embeds = []
for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
with torch.cuda.amp.autocast() and torch.no_grad():
for key in batch_dict:
batch_dict[key] = batch_dict[key].to("cuda")
outputs: BaseModelOutput = self.encoder(**batch_dict)
if MODE == "Dense":
embeds = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
embeds = F.normalize(embeds, p=2, dim=1)
elif MODE == "Sparse":
embeds = sparse_pooling(self.encoder.module.head,self.encoder.module, batch_dict, outputs.last_hidden_state, batch_dict['attention_mask'])
else:
embeds = concat_pooling(self.encoder.module.head,self.encoder.module, batch_dict, outputs.last_hidden_state, batch_dict['attention_mask'])
encoded_embeds.append(embeds.cpu().numpy())
return np.concatenate(encoded_embeds, axis=0)
@torch.no_grad()
def encode_queries(self, queries: list[str], **kwargs) -> list[np.ndarray] | list[torch.Tensor]:
"""
Returns a list of embeddings for the given sentences.
Args:
queries: List of sentences to encode
Returns:
List of embeddings for the given sentences
"""
queries = [query for query in queries]
return self.encode(queries, is_query=True, **kwargs)
@torch.no_grad()
def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
# borrowed from mteb.abstasks.AbsTaskRetrieval.DRESModel
if type(corpus) is dict:
sentences = [
(corpus["title"][i] + " " + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
]
elif isinstance(corpus[0], dict):
sentences = [
(doc["title"] + " " + doc["text"]).strip()
if "title" in doc
else doc["text"].strip()
for doc in corpus
]
else:
sentences = corpus
is_query = False
return self.encode(sentences, is_query=is_query, **kwargs)
# Evaluation driver: build the encoder once, then run every selected task.
model = DenseEncoder()

# The full benchmark list is assigned first and then overridden by a single
# task; drop the second assignment to evaluate everything.
task_names = MTEB_TASK_LIST
task_names = ["NFCorpus"]
lang = ["en", "zh", "zh-CN"]

for task in task_names:
    try:
        evaluation = MTEB(tasks=[task], task_langs=lang)
        task_cls = evaluation.tasks[0]
        task_name: str = task_cls.metadata_dict["name"]
        task_type: str = task_cls.metadata_dict["type"]

        # Select the task-specific instruction used as query prefix.
        instruction = get_task_def_by_task_name_and_type(task_name, task_type)
        model.instruction = get_detailed_instruct(instruction)
        print(model.instruction)

        # MSMARCO is scored on "dev", C-MTEB tasks on the splits declared in
        # their metadata, everything else on "test".
        if task == "MSMARCO":
            eval_splits = ["dev"]
        elif task not in CMTEB_TASK_LIST:
            eval_splits = ["test"]
        else:
            eval_splits = task_cls.metadata_dict["eval_splits"]

        evaluation.run(model, eval_splits=eval_splits, overwrite_results=True)
    except Exception as e:
        # Best effort: report the failure and move on to the next task.
        import traceback
        print(traceback.format_exc())