# copy from https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct/blob/main/scripts/eval_mteb.py

#### ATTENTION ####
# To reproduce the results of Sparse and Dense + Sparse, you need to hack the
# MTEB RetrievalEvaluator in mteb/evaluation/evaluators/RetrievalEvaluator.py
#
# class RetrievalEvaluator(Evaluator):
#     def __init__(
#         self,
#         retriever=None,
#         task_name: str | None = None,
#         k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
#         score_function: str = "cos_sim",
#         encode_kwargs: dict[str, Any] = {},
#         **kwargs,
#     ):
#
# You need to change the default score_function to "dot" to reproduce the
# results of Sparse and Dense + Sparse.

import math
import traceback
from functools import partial
from typing import List, Dict

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput
from mteb import MTEB

MODE = "Dense"  # "Dense" or "Sparse" or "Dense + Sparse"

TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]

TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]

TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "FiQA2018",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
    "NFCorpus",
    "NQ",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "HotpotQA",
    "MSMARCO",
    "FEVER",
]

TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]

MTEB_TASK_LIST = (
    TASK_LIST_RETRIEVAL
    + TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_STS
)

CMTEB_TASK_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
    "AmazonReviewsClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MultilingualSentiment",
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
    "Ocnli",
    "Cmnli",
    "T2Reranking",
    "MMarcoReranking",
    "CMedQAv1-reranking",
    "CMedQAv2-reranking",
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
    "STS22",
]

# The combined list runs the C-MTEB tasks first, then the English MTEB tasks.
MTEB_TASK_LIST = CMTEB_TASK_LIST + MTEB_TASK_LIST


def get_detailed_instruct(task_description: str) -> str:
    """Wrap a task description in the "Instruction: ... Query: " prefix.

    Returns an empty string when ``task_description`` is falsy (``None`` or
    ``""``), so tasks without an instruction get no prefix at all.
    """
    if not task_description:
        return ""
    return f"Instruction: {task_description} Query: "


def get_task_def_by_task_name_and_type(
    task_name: str,
    task_type: str,
    default_instruct="",
):
    """Look up the instruction text for an MTEB / C-MTEB task.

    Args:
        task_name: Task name as reported by the task metadata.
        task_type: Task type as reported by the task metadata.
        default_instruct: Returned when ``task_type`` matches none of the
            handled types.

    Returns:
        The instruction string, ``None`` for STS tasks and for task names
        missing from the table of a handled type, or ``default_instruct``
        for an unhandled task type.
    """
    if task_type in ["STS"]:
        return None

    if task_type in ["Summarization"]:
        return "Given a news summary, retrieve other semantically similar summaries"

    if task_type in ["Classification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual.",
            "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment.",
            "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category.",
            "Banking77Classification": "Given a online banking query, find the corresponding intents.",
            "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.",
            "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset.",
            "MassiveIntentClassification": "Given a user utterance as query, find the user intents.",
            "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios.",
            "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation.",
            "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation.",
            "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic.",
            "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral.",
            # C-MTEB eval instructions
            "TNews": "根据标题确定新闻的类别。",
            "IFlyTek": "根据描述确定APP的类别。",
            "MultilingualSentiment": "将亚马逊评论分为积极、消极或中立情绪。",
            "JDReview": "将商品评论分为积极或消极情绪。",
            "OnlineShopping": "将商品评论分为积极或消极情绪。",
            "Waimai": "将外卖评论分为积极或消极情绪。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Clustering"]:
        task_name_to_instruct: Dict[str, str] = {
            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts.",
            "ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles.",
            "BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts.",
            "BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles.",
            "MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts.",
            "MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles.",
            "RedditClustering": "Identify the topic or theme of Reddit posts based on the titles.",
            "RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts.",
            "StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles.",
            "StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs.",
            "TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles.",
            # C-MTEB eval instructions
            "CLSClusteringS2S": "根据标题确定文章的类别。",
            "CLSClusteringP2P": "根据标题和摘要确定文章的类别。",
            "ThuNewsClusteringS2S": "根据标题确定新闻的类别。",
            "ThuNewsClusteringP2P": "根据标题和摘要确定新闻的类别。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Reranking", "PairClassification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum.",
            "MindSmallReranking": "Retrieve relevant news articles based on user browsing history.",
            "SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers.",
            "StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum.",
            "SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum.",
            "TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet.",
            "TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet.",
            # C-MTEB eval instructions
            "T2Reranking": "为这个问题检索相关段落。",
            "MMarcoReranking": "为这个查询检索相关段落。",
            "CMedQAv1-reranking": "为这个医疗问题检索相关回答。",
            "CMedQAv2-reranking": "为这个医疗问题检索相关回答。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Retrieval"]:
        # All CQADupstack sub-tasks share a single instruction.
        if task_name.lower().startswith("cqadupstack"):
            return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question"

        task_name_to_instruct: Dict[str, str] = {
            "ArguAna": "Given a claim, find documents that refute the claim.",
            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim.",
            "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia.",
            "FEVER": "Given a claim, retrieve documents that support or refute the claim.",
            "FiQA2018": "Given a financial question, retrieve user replies that best answer the question.",
            "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question.",
            "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query.",
            "NFCorpus": "Given a question, retrieve relevant documents that best answer the question.",
            "NQ": "Given a question, retrieve Wikipedia passages that answer the question.",
            "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question.",
            "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.",
            "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim.",
            "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question.",
            "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query.",
            # C-MTEB eval instructions
            "T2Retrieval": "为这个问题检索相关段落。",
            "MMarcoRetrieval": "为这个查询检索相关段落。",
            "DuRetrieval": "为这个问题检索相关百度知道回答。",
            "CovidRetrieval": "为这个问题检索相关政策回答。",
            "CmedqaRetrieval": "为这个医疗问题检索相关回答。",
            "EcomRetrieval": "为这个查询检索相关商品标题。",
            "MedicalRetrieval": "为这个医疗问题检索相关回答。",
            "VideoRetrieval": "为这个电影标题检索相关段落。",
        }
        # Also accept fully lower-cased task names.
        task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
        return task_name_to_instruct.get(task_name, None)

    return default_instruct


def _transform_func(tokenizer: PreTrainedTokenizerFast, examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize ``examples['input_texts']``, truncating to 1024 tokens."""
    batch_dict = tokenizer(examples['input_texts'], max_length=1024, padding=True, truncation=True)
    return batch_dict


def mean_pooling(hidden, attention_mask):
    """Average ``hidden`` over dim 1, counting only positions where the mask is set.

    Assumes ``hidden`` is (batch, seq, dim) and ``attention_mask`` is
    (batch, seq) with 0/1 entries -- TODO confirm against the model output.
    """
    s = torch.sum(hidden * attention_mask.unsqueeze(-1).float(), dim=1)
    d = attention_mask.sum(dim=1, keepdim=True).float()
    return s / d


def wmean_pooling(hidden, attention_mask):
    """Weighted mean over dim 1, with weights ``mask * cumsum(mask)``.

    Each unmasked position is weighted by the running count of unmasked
    positions up to and including it, so later positions weigh more.
    """
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    hidden_masked = hidden * attention_mask_.unsqueeze(-1).float()
    s = torch.sum(hidden_masked, dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps


def reverse_wmean_pooling(hidden, attention_mask):
    """Rescale per-token hidden states by the inverse of the wmean weights.

    Multiplies ``hidden`` by (sum of position weights / number of unmasked
    tokens) per sequence and divides each token by its own position weight
    (clamped to 1e-9 so masked positions do not divide by zero).
    NOTE(review): presumably this undoes a weighting applied inside the
    model's forward pass -- confirm against the model's remote code.
    """
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    d = (
        attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float()
        / attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
    )
    hidden = hidden.float() * d
    return hidden / torch.clamp(attention_mask_.unsqueeze(-1).float(), min=1e-9)


def sparse_pooling(head, model, items, hidden, attention_mask):
    """Build a vocabulary-sized sparse embedding from per-token weights.

    Args:
        head: Callable applied to the normalised hidden states to produce
            per-token weights (ReLU is applied to its output).
        model: Object exposing ``embed_tokens.weight``; its first dimension
            is used as the vocabulary size.
        items: Batch mapping containing ``"input_ids"``.
        hidden: Last hidden states of the encoder.
        attention_mask: Attention mask matching ``hidden``.

    Returns:
        Tensor with one row per input sequence and one column per vocabulary
        entry, holding the max weight any occurrence of that token received.
    """
    # Reverse weighted mean pooling, because the hidden states are modified in the model.
    hidden = reverse_wmean_pooling(hidden, attention_mask)
    # Normalise every sequence by its largest token norm before applying the head.
    max_hidden_norm = torch.max(torch.norm(hidden, dim=-1), dim=-1).values
    token_weights = torch.relu(head(hidden.float() / max_hidden_norm.unsqueeze(-1).unsqueeze(-1)))
    vocab_size = model.embed_tokens.weight.size(0)
    input_ids = items["input_ids"]

    sparse_embedding_chunks = []
    # Scatter one sequence at a time: the dense (seq, vocab) scratch tensor is
    # large, so small chunks keep peak memory down.
    mini_chunk_size = 1
    mini_chunk_size = min(mini_chunk_size, hidden.shape[0])
    for i in range(0, token_weights.size(0), mini_chunk_size):
        now_chunk_size = min(mini_chunk_size, token_weights.size(0) - i)
        sparse_embedding = torch.zeros(
            now_chunk_size,
            input_ids.size(1),
            vocab_size,
            dtype=token_weights.dtype,
            device=token_weights.device,
        )
        scattered = torch.scatter(
            sparse_embedding,
            dim=-1,
            index=input_ids[i:i + now_chunk_size, :].unsqueeze(-1),
            src=token_weights[i:i + now_chunk_size, :],
        )
        # Max over the sequence dimension keeps the strongest weight per token id.
        sparse_embedding_chunks.append(torch.max(scattered, dim=1).values)

    sparse_embedding = torch.concat(sparse_embedding_chunks, dim=0)
    # Token ids whose weights are always zeroed out.
    # NOTE(review): presumably special tokens of this tokenizer -- confirm.
    unused_tokens = [0, 1, 2, 73440]
    sparse_embedding[:, unused_tokens] *= 0.
    return sparse_embedding


def concat_pooling(head, model, items, hidden, attention_mask):
    """Concatenate the L2-normalised dense embedding with the scaled sparse one.

    The sparse part is multiplied by ``sqrt(0.3)`` before concatenation along
    the last dimension.
    """
    mean_reps = mean_pooling(hidden, attention_mask)
    mean_reps = F.normalize(mean_reps, p=2, dim=1)
    sparse_reps = sparse_pooling(head, model, items, hidden, attention_mask) * math.sqrt(0.3)
    return torch.cat([mean_reps, sparse_reps], dim=-1)


class DenseEncoder(torch.nn.Module):
    """MTEB-compatible encoder around ``openbmb/UltraRAG-Embedding``.

    The pooling strategy is selected by the module-level ``MODE`` constant.
    ``instruction`` is prepended to every sentence that is not explicitly
    encoded as a corpus document.
    """

    def __init__(self, **kwargs):
        """Load the model and tokenizer onto CUDA; ``kwargs`` are ignored."""
        super().__init__()
        model_path = "openbmb/UltraRAG-Embedding"
        self.encoder = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.gpu_count = torch.cuda.device_count()
        self.instruction = ""

        self.encoder.eval()

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

    def _base_model(self):
        """Return the underlying model, unwrapping ``DataParallel`` if present."""
        encoder = self.encoder
        if isinstance(encoder, torch.nn.DataParallel):
            return encoder.module
        return encoder

    @torch.no_grad()
    def encode(self, sentences, is_query=None, **kwargs) -> np.ndarray:
        """Encode ``sentences`` into a single 2-D numpy array of embeddings.

        Args:
            sentences (`List[str]`): Sentences to encode.
            is_query: When exactly ``False`` the instruction prefix is not
                added; any other value (including ``None``) adds it.
            **kwargs: Accepted for MTEB compatibility and ignored.

        Returns:
            `np.ndarray`: One row per input sentence.
        """
        if is_query is not False:
            sentences = [self.instruction + s for s in sentences]
        dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
        dataset.set_transform(partial(_transform_func, self.tokenizer))

        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        data_loader = DataLoader(
            dataset,
            batch_size=128 * self.gpu_count,
            shuffle=False,
            drop_last=False,
            num_workers=2,
            collate_fn=data_collator,
            pin_memory=True,
        )

        # The sparse head and embedding table live on the unwrapped model;
        # `self.encoder` is only a DataParallel wrapper when several GPUs exist.
        base_model = self._base_model()

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            # NOTE: this loop used to run under
            # `with torch.cuda.amp.autocast() and torch.no_grad():`. The `and`
            # expression evaluates to the no_grad() object alone, so autocast
            # was never entered. Gradients are already disabled by the
            # decorator, so no context manager is needed to keep the same
            # numerics. Wrap the body in `torch.autocast("cuda")` deliberately
            # if mixed precision is actually wanted.
            for key in batch_dict:
                batch_dict[key] = batch_dict[key].to("cuda")

            outputs: BaseModelOutput = self.encoder(**batch_dict)
            hidden = outputs.last_hidden_state
            mask = batch_dict['attention_mask']

            if MODE == "Dense":
                embeds = mean_pooling(hidden, mask)
                embeds = F.normalize(embeds, p=2, dim=1)
            elif MODE == "Sparse":
                embeds = sparse_pooling(base_model.head, base_model, batch_dict, hidden, mask)
            else:
                embeds = concat_pooling(base_model.head, base_model, batch_dict, hidden, mask)

            encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)

    @torch.no_grad()
    def encode_queries(self, queries: list[str], **kwargs) -> list[np.ndarray] | list[torch.Tensor]:
        """Encode queries, always prepending the current instruction.

        Args:
            queries: List of sentences to encode

        Returns:
            Embeddings for the given sentences, as produced by ``encode``.
        """
        return self.encode(list(queries), is_query=True, **kwargs)

    @torch.no_grad()
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
        """Encode corpus documents without the instruction prefix.

        ``corpus`` may be a dict of columns (``"text"`` and optionally
        ``"title"``), a list of such dicts, or a list of plain strings.
        Title and text are joined with a single space and stripped.
        """
        # borrowed from mteb.abstasks.AbsTaskRetrieval.DRESModel
        if isinstance(corpus, dict):
            texts = corpus["text"]
            if "title" in corpus:
                sentences = [
                    (title + " " + text).strip()
                    for title, text in zip(corpus["title"], texts)
                ]
            else:
                sentences = [text.strip() for text in texts]
        elif isinstance(corpus[0], dict):
            sentences = [
                (doc["title"] + " " + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                for doc in corpus
            ]
        else:
            sentences = corpus
        return self.encode(sentences, is_query=False, **kwargs)


model = DenseEncoder()
task_names = MTEB_TASK_LIST
# Only a single task is evaluated; remove this override to run the full list.
task_names = ["NFCorpus"]
lang = ["en", "zh", "zh-CN"]

for task in task_names:
    try:
        evaluation = MTEB(tasks=[task], task_langs=lang)
        task_cls = evaluation.tasks[0]
        task_name: str = task_cls.metadata_dict["name"]
        task_type: str = task_cls.metadata_dict["type"]
        instruction = get_task_def_by_task_name_and_type(task_name, task_type)
        model.instruction = get_detailed_instruct(instruction)
        print(model.instruction)
        if task == "MSMARCO":
            eval_splits = ["dev"]
        elif task in CMTEB_TASK_LIST:
            eval_splits = task_cls.metadata_dict["eval_splits"]
        else:
            eval_splits = ["test"]
        evaluation.run(model, eval_splits=eval_splits, overwrite_results=True)
    except Exception:
        # Best-effort sweep: report the failure and move on to the next task.
        print(traceback.format_exc())