MODE = "Dense"
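
# MODE selects which representation encode() returns (see the pooling functions below):
#   "Dense"       - L2-normalized mean pooling over the last hidden states
#   "Sparse"      - vocabulary-sized lexical weights produced by the model's token-weight head
#   anything else - concatenation of the dense vector and the (down-weighted) sparse vector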

TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]

TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]

TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "FiQA2018",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
    "NFCorpus",
    "NQ",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "HotpotQA",
    "MSMARCO",
    "FEVER",
]

TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]

MTEB_TASK_LIST = (
    TASK_LIST_RETRIEVAL
    + TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_STS
)

CMTEB_TASK_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
    "AmazonReviewsClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
    "Ocnli",
    "Cmnli",
    "T2Reranking",
    "MMarcoReranking",
    "CMedQAv1-reranking",
    "CMedQAv2-reranking",
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
    "STS22",
]

MTEB_TASK_LIST = CMTEB_TASK_LIST + MTEB_TASK_LIST
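# MTEB_TASK_LIST now covers the Chinese C-MTEB tasks followed by the English MTEB tasks above.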

import torch
import torch.nn.functional as F
import tqdm
import numpy as np
import math

from functools import partial
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput
from typing import List, Dict
from mteb import MTEB


def get_detailed_instruct(task_description: str) -> str:
    if not task_description:
        return ""

    return "Instruction: {} Query: ".format(task_description)


def get_task_def_by_task_name_and_type(
    task_name: str,
    task_type: str,
    default_instruct="",
):
    if task_type in ["STS"]:
        return None

    if task_type in ["Summarization"]:
        return "Given a news summary, retrieve other semantically similar summaries"

    if task_type in ["Classification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual.",
            "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment.",
            "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category.",
            "Banking77Classification": "Given a online banking query, find the corresponding intents.",
            "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.",
            "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset.",
            "MassiveIntentClassification": "Given a user utterance as query, find the user intents.",
            "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios.",
            "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation.",
            "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation.",
            "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic.",
            "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral.",

            "TNews": "根据标题确定新闻的类别。",
            "IFlyTek": "根据描述确定APP的类别。",
            "MultilingualSentiment": "将亚马逊评论分为积极、消极或中立情绪。",
            "JDReview": "将商品评论分为积极或消极情绪。",
            "OnlineShopping": "将商品评论分为积极或消极情绪。",
            "Waimai": "将外卖评论分为积极或消极情绪。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Clustering"]:
        task_name_to_instruct: Dict[str, str] = {
            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts.",
            "ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles.",
            "BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts.",
            "BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles.",
            "MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts.",
            "MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles.",
            "RedditClustering": "Identify the topic or theme of Reddit posts based on the titles.",
            "RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts.",
            "StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles.",
            "StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs.",
            "TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles.",

            "CLSClusteringS2S": "根据标题确定文章的类别。",
            "CLSClusteringP2P": "根据标题和摘要确定文章的类别。",
            "ThuNewsClusteringS2S": "根据标题确定新闻的类别。",
            "ThuNewsClusteringP2P": "根据标题和摘要确定新闻的类别。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Reranking", "PairClassification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum.",
            "MindSmallReranking": "Retrieve relevant news articles based on user browsing history.",
            "SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers.",
            "StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum.",
            "SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum.",
            "TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet.",
            "TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet.",

            "T2Reranking": "为这个问题检索相关段落。",
            "MMarcoReranking": "为这个查询检索相关段落。",
            "CMedQAv1-reranking": "为这个医疗问题检索相关回答。",
            "CMedQAv2-reranking": "为这个医疗问题检索相关回答。",
        }
        return task_name_to_instruct.get(task_name, None)

    if task_type in ["Retrieval"]:
        if task_name.lower().startswith("cqadupstack"):
            return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question"

        task_name_to_instruct: Dict[str, str] = {
            "ArguAna": "Given a claim, find documents that refute the claim.",
            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim.",
            "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia.",
            "FEVER": "Given a claim, retrieve documents that support or refute the claim.",
            "FiQA2018": "Given a financial question, retrieve user replies that best answer the question.",
            "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question.",
            "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query.",
            "NFCorpus": "Given a question, retrieve relevant documents that best answer the question.",
            "NQ": "Given a question, retrieve Wikipedia passages that answer the question.",
            "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question.",
            "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.",
            "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim.",
            "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question.",
            "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query.",

            "T2Retrieval": "为这个问题检索相关段落。",
            "MMarcoRetrieval": "为这个查询检索相关段落。",
            "DuRetrieval": "为这个问题检索相关百度知道回答。",
            "CovidRetrieval": "为这个问题检索相关政策回答。",
            "CmedqaRetrieval": "为这个医疗问题检索相关回答。",
            "EcomRetrieval": "为这个查询检索相关商品标题。",
            "MedicalRetrieval": "为这个医疗问题检索相关回答。",
            "VideoRetrieval": "为这个电影标题检索相关段落。",
        }

        task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})

        return task_name_to_instruct.get(task_name, None)

    return default_instruct


def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    # Tokenize on the fly, truncating each input to 1024 tokens.
    batch_dict = tokenizer(examples['input_texts'],
                           max_length=1024,
                           padding=True,
                           truncation=True)

    return batch_dict
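
# Note: padding=True above only pads within whatever group of examples the datasets
# transform happens to receive; the real dynamic padding is done later by
# DataCollatorWithPadding (pad_to_multiple_of=8), which pads each DataLoader batch to
# the longest sequence rounded up to a multiple of 8 (commonly done for tensor-core efficiency).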


def mean_pooling(hidden, attention_mask):
    # Unweighted mean over the non-padding token states.
    s = torch.sum(hidden * attention_mask.unsqueeze(-1).float(), dim=1)
    d = attention_mask.sum(dim=1, keepdim=True).float()
    return s / d


def wmean_pooling(hidden, attention_mask):
    # Position-weighted mean: the token at position i gets weight i + 1, so later
    # tokens contribute more to the pooled representation.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    hidden_masked = hidden * attention_mask_.unsqueeze(-1).float()
    s = torch.sum(hidden_masked, dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps


def reverse_wmean_pooling(hidden, attention_mask):
    # Rescale each token state by (mean position weight / its own position weight),
    # i.e. divide out the linear position weighting used in wmean_pooling, presumably
    # so the per-token weights fed to the sparse head are not biased toward later positions.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() / attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
    hidden = hidden.float() * d
    return hidden / torch.clamp(attention_mask_.unsqueeze(-1).float(), min=1e-9)
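
# Worked example of the position weighting above (assuming right padding):
#   attention_mask = [1, 1, 1, 0]  ->  cumulative weights = [1, 2, 3, 0]
#   wmean_pooling averages the hidden states with weights 1:2:3, while
#   reverse_wmean_pooling multiplies token i by (mean weight 2) / (weight i + 1),
#   i.e. [2.0, 1.0, 0.667, ...], which cancels the linear position bias.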


def sparse_pooling(head, model, items, hidden, attention_mask):
    # Map each token's hidden state to a per-token relevance weight, then scatter the
    # weights into a |vocab|-sized vector and max-pool over the sequence dimension.
    hidden = reverse_wmean_pooling(hidden, attention_mask)
    # Scaling by the per-sequence max token norm keeps the head input in a stable range.
    max_hidden_norm = torch.max(torch.norm(hidden, dim=-1), dim=-1).values
    token_weights = torch.relu(head(hidden.float() / max_hidden_norm.unsqueeze(-1).unsqueeze(-1)))
    vocab_size = model.embed_tokens.weight.size(0)
    input_ids = items["input_ids"]
    sparse_embedding_chunks = []
    # Build the (batch, seq_len, vocab) scatter target in mini-chunks to bound peak memory.
    mini_chunk_size = 1
    mini_chunk_size = min(mini_chunk_size, hidden.shape[0])
    for i in range(0, token_weights.size(0), mini_chunk_size):
        now_chunk_size = min(mini_chunk_size, token_weights.size(0) - i)
        sparse_embedding = torch.zeros(now_chunk_size, input_ids.size(1), vocab_size,
                                       dtype=token_weights.dtype,
                                       device=token_weights.device)
        sparse_embedding_chunks.append(torch.max(torch.scatter(sparse_embedding, dim=-1, index=input_ids[i:i + now_chunk_size, :].unsqueeze(-1), src=token_weights[i:i + now_chunk_size, :]), dim=1).values)
    sparse_embedding = torch.concat(sparse_embedding_chunks, dim=0)
    # Zero out special-token ids so they never contribute to the lexical match
    # (0/1/2 and 73440 are presumably the model's special tokens).
    unused_tokens = [0, 1, 2, 73440]
    sparse_embedding[:, unused_tokens] *= 0.
    return sparse_embedding
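
# The result is a SPLADE-style lexical representation: for each input id that occurs
# in the text, the vector holds the maximum ReLU weight any occurrence of that token
# received; every other vocabulary entry stays zero. Illustrative sketch with
# hypothetical ids and weights:
#   input_ids     = [[5, 9, 5]]
#   token_weights = [[0.2, 0.7, 0.5]]
#   -> sparse vector has 0.5 at id 5, 0.7 at id 9, zeros elsewhere.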


def concat_pooling(head, model, items, hidden, attention_mask):
    # Dense part: L2-normalized mean-pooled embedding.
    mean_reps = mean_pooling(hidden, attention_mask)
    mean_reps = F.normalize(mean_reps, p=2, dim=1)
    # Sparse part, down-weighted before concatenation.
    sparse_reps = sparse_pooling(head, model, items, hidden, attention_mask) * math.sqrt(0.3)
    return torch.cat([mean_reps, sparse_reps], dim=-1)
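
# If the downstream similarity is a plain dot product of the concatenated vectors,
# scaling both sides' sparse blocks by sqrt(0.3) weights the hybrid score as
# dense_score + 0.3 * sparse_score; with cosine similarity the effect is the same up
# to normalization.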


class DenseEncoder(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        model_path = "openbmb/UltraRAG-Embedding"
        self.encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.gpu_count = torch.cuda.device_count()
        self.instruction = ""

        self.encoder.eval()

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

    @torch.no_grad()
    def encode(self, sentences, is_query=None, **kwargs) -> np.ndarray:
        """Return embeddings for the given sentences.

        Args:
            sentences (`List[str]`): sentences to encode
            is_query (`bool`, optional): if not False, the task instruction is prepended

        Returns:
            `np.ndarray`: embeddings of shape (len(sentences), dim)
        """
        if is_query is not False:
            sentences = [self.instruction + s for s in sentences]
        dataset: Dataset = Dataset.from_dict({'input_texts': sentences})

        dataset.set_transform(partial(_transform_func, self.tokenizer))

        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        data_loader = DataLoader(
            dataset,
            batch_size=128 * self.gpu_count,
            shuffle=False,
            drop_last=False,
            num_workers=2,
            collate_fn=data_collator,
            pin_memory=True)

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            # Gradients are already disabled by the @torch.no_grad() decorator.
            with torch.cuda.amp.autocast():
                for key in batch_dict:
                    batch_dict[key] = batch_dict[key].to("cuda")
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                # DataParallel hides the underlying model behind .module; unwrap it so the
                # sparse head stays reachable in both single- and multi-GPU runs.
                encoder_module = self.encoder.module if hasattr(self.encoder, "module") else self.encoder
                if MODE == "Dense":
                    embeds = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
                    embeds = F.normalize(embeds, p=2, dim=1)
                elif MODE == "Sparse":
                    embeds = sparse_pooling(encoder_module.head, encoder_module, batch_dict, outputs.last_hidden_state, batch_dict['attention_mask'])
                else:
                    embeds = concat_pooling(encoder_module.head, encoder_module, batch_dict, outputs.last_hidden_state, batch_dict['attention_mask'])
                encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)

    @torch.no_grad()
    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """Encode queries with the task instruction prepended (see `encode`)."""
        return self.encode(queries, is_query=True, **kwargs)

    @torch.no_grad()
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
        # Corpus entries may arrive as a dict of columns, a list of dicts, or a plain
        # list of strings; normalize all three to a list of "title text" strings.
        if type(corpus) is dict:
            sentences = [
                (corpus["title"][i] + " " + corpus["text"][i]).strip()
                if "title" in corpus
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        elif isinstance(corpus[0], dict):
            sentences = [
                (doc["title"] + " " + doc["text"]).strip()
                if "title" in doc
                else doc["text"].strip()
                for doc in corpus
            ]
        else:
            sentences = corpus
        return self.encode(sentences, is_query=False, **kwargs)


model = DenseEncoder()
task_names = MTEB_TASK_LIST
# Quick-test override: evaluate a single task. Comment this line out to run the full list.
task_names = ["NFCorpus"]
lang = ["en", "zh", "zh-CN"]

for task in task_names:
    try:
        evaluation = MTEB(tasks=[task], task_langs=lang)
        task_cls = evaluation.tasks[0]
        task_name: str = task_cls.metadata_dict["name"]
        task_type: str = task_cls.metadata_dict["type"]
        instruction = get_task_def_by_task_name_and_type(task_name, task_type)
        model.instruction = get_detailed_instruct(instruction)
        print(model.instruction)
        if task == "MSMARCO":
            eval_splits = ["dev"]
        elif task in CMTEB_TASK_LIST:
            eval_splits = task_cls.metadata_dict["eval_splits"]
        else:
            eval_splits = ["test"]
        evaluation.run(model, eval_splits=eval_splits, overwrite_results=True)
    except Exception as e:
        # Keep going even if an individual task fails (e.g. a missing dataset).
        import traceback
        print(traceback.format_exc())
        continue
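
# In most mteb versions, evaluation.run writes per-task JSON results under ./results/
# unless an output_folder argument is supplied; the exact default may vary by version.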