infgrad committed on
Commit
afce9e0
·
verified ·
1 Parent(s): 95312db

Upload 3 files

Browse files
scripts/evaluate_en_mteb/model_for_evaluate.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ import numpy as np
4
+ from typing import Sequence, Any
5
+ from mteb.encoder_interface import PromptType
6
+ from mteb import Encoder
7
+ from sentence_transformers import SentenceTransformer
8
+ from mteb_utils import get_task_def_by_task_name_and_type, get_detailed_instruct, get_task_type_en
9
+
10
+
11
def jasper_vl_forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
    """Forward pass that stores the model's sentence embedding in ``features``.

    Only ``input_ids`` and ``attention_mask`` (plus ``pixel_values`` when the
    batch carries it) are handed to ``self.auto_model``; any other entries of
    ``features`` are left untouched and not forwarded.

    Args:
        self: Object exposing a callable ``auto_model`` whose result can be
            indexed with ``"sentence_embedding"``.
        features: Batch features; must contain ``input_ids`` and
            ``attention_mask``. Mutated in place.
        **kwargs: Extra keyword arguments passed through to ``auto_model``.

    Returns:
        The same ``features`` mapping, now holding ``"sentence_embedding"``.
    """
    model_inputs = {key: features[key] for key in ("input_ids", "attention_mask")}
    # Image inputs are optional: text-only batches simply do not carry them.
    if "pixel_values" in features:
        model_inputs["pixel_values"] = features["pixel_values"]
    outputs = self.auto_model(**model_inputs, **kwargs)
    features["sentence_embedding"] = outputs["sentence_embedding"]
    return features
18
+
19
+
20
class MTEB_Sentence_Transformer(Encoder):
    """MTEB encoder backed by a SentenceTransformer multi-process pool.

    The model is loaded on CPU, its first module's ``forward`` is replaced by
    ``jasper_vl_forward``, and a multi-process pool is started immediately.
    The caller is responsible for stopping the pool with
    ``instance.model.stop_multi_process_pool(instance.pool)``.
    """

    def __init__(
            self,
            model_path_or_name: str,
            lang: str,
            batch_size: int,
            max_length: int,
            device: str | None = None
    ) -> None:
        """Load the model, patch its forward pass and start the worker pool.

        Args:
            model_path_or_name: Model id or local path given to
                ``SentenceTransformer``.
            lang: Language tag; stored on the instance but not read by this
                class.
            batch_size: Batch size used for every ``encode`` call.
            max_length: Value assigned to the model's ``max_seq_length``.
            device: Forwarded to ``Encoder.__init__`` only; the model itself
                is always loaded with ``device="cpu"``.
        """
        super().__init__(device=device)
        model = SentenceTransformer(
            model_path_or_name,
            trust_remote_code=True,
            device="cpu",
            model_kwargs={
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "sdpa"
            },
            config_kwargs={"is_text_encoder": True, "vector_dim": 12288},
            tokenizer_kwargs={"padding_side": "right"}
        )
        first_module = model._first_module()
        first_module.forward = functools.partial(jasper_vl_forward, first_module)
        # The sequence limit must be applied BEFORE the pool is started:
        # the workers are handed the model when they are launched, so a
        # limit set afterwards would only change this process's copy and the
        # processes that actually tokenize would ignore ``max_length``.
        model.max_seq_length = max_length
        self.model = model

        self.lang = lang
        self.batch_size = batch_size
        self.pool = self.model.start_multi_process_pool()

    def encode(
            self,
            sentences: Sequence[str],
            *,
            task_name: str,
            prompt_type: PromptType | None = None,
            **kwargs: Any,
    ) -> np.ndarray:
        """Encode ``sentences`` for the given MTEB task.

        Retrieval queries and all inputs of non-retrieval tasks are prefixed
        with the task instruction; retrieval passages are encoded as-is.
        Inputs that are empty or whitespace-only after this step are replaced
        by ``"<|endoftext|>"``.

        Args:
            sentences: Texts to embed.
            task_name: MTEB task name used to look up the task type and
                instruction.
            prompt_type: ``"query"`` or ``"passage"``; required for
                retrieval tasks.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Normalized embeddings as a ``float32`` array.

        Raises:
            ValueError: If the task is a retrieval task and ``prompt_type``
                is neither ``"query"`` nor ``"passage"`` (task lookup helpers
                may also raise for unknown task names).
        """
        task_type = get_task_type_en(task_name)
        do_normalize = True
        instruction = get_detailed_instruct(get_task_def_by_task_name_and_type(task_name, task_type))
        if task_type == "Retrieval":
            if prompt_type == "query":
                sentences = [instruction + sen for sen in sentences]
            elif prompt_type == "passage":
                # Documents are embedded without any instruction prefix.
                pass
            else:
                raise ValueError(f"unknown prompt_type:{prompt_type}")
        else:
            sentences = [instruction + sen for sen in sentences]
        # Replace blank inputs with a placeholder token. Note this runs after
        # prefixing, so it only ever triggers for un-prefixed texts.
        sentences = [i if i.strip() else "<|endoftext|>" for i in sentences]
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            batch_size=self.batch_size,
            show_progress_bar=True,
            normalize_embeddings=do_normalize
        )

        vectors = vectors.astype(dtype=np.float32)
        print("vectors.shape", vectors.shape)
        return vectors
scripts/evaluate_en_mteb/mteb_utils.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
# Task-name groupings by (apparent) evaluation cost. The three lists overlap:
# some names appear in more than one of them.
LONG_TIME_TASK_NAMES = [
    "MSMARCO", "FEVER", "HotpotQA", "ClimateFEVER", "DBPedia", "NQ",
    "ArxivClusteringP2P", "ArxivClusteringS2S",
    "RedditClusteringP2P", "RedditClustering",
    "QuoraRetrieval", "StackExchangeClustering", "Touche2020",
    "MindSmallReranking", "AmazonPolarityClassification",
    "BiorxivClusteringP2P", "StackExchangeClusteringP2P", "TRECCOVID",
]

SHORT_TIME_TASK_NAMES = [
    "BIOSSES", "STS17", "STS16",
    "AskUbuntuDupQuestions", "SummEval", "SciFact",
    "TweetSentimentExtractionClassification", "EmotionClassification",
    "SprintDuplicateQuestions",
]

MID_TIME_TASK_NAMES = [
    "BIOSSES", "STS17", "STS22", "STS16", "STSBenchmark",
    "STS13", "STS15", "STS12", "STS14",
    "AskUbuntuDupQuestions", "TwitterSemEval2015", "SummEval",
    "SICK-R", "NFCorpus", "SciFact",
    "CQADupstackWebmastersRetrieval", "TwitterURLCorpus",
    "SprintDuplicateQuestions",
    "CQADupstackAndroidRetrieval", "CQADupstackMathematicaRetrieval",
    "ArguAna",
    "CQADupstackProgrammersRetrieval", "SCIDOCS",
    "StackOverflowDupQuestions",
    "EmotionClassification", "TweetSentimentExtractionClassification",
    "CQADupstackStatsRetrieval",
    "CQADupstackGisRetrieval", "CQADupstackWordpressRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackPhysicsRetrieval", "CQADupstackGamingRetrieval",
    "SciDocsRR", "FiQA2018",
    "CQADupstackUnixRetrieval", "ToxicConversationsClassification",
    "Banking77Classification",
    "TwentyNewsgroupsClustering", "MedrxivClusteringS2S",
    "ImdbClassification",
    "MTOPDomainClassification", "BiorxivClusteringS2S",
    "AmazonCounterfactualClassification",
    "MassiveScenarioClassification", "MedrxivClusteringP2P",
    "MTOPIntentClassification",
    "MassiveIntentClassification", "CQADupstackTexRetrieval",
    "AmazonReviewsClassification",
    "TRECCOVID", "BiorxivClusteringP2P", "StackExchangeClusteringP2P",
    "StackExchangeClustering",
]
49
+
50
# Chinese MTEB (C-MTEB) task names, grouped by task family.
# 'MultilingualSentiment' used to be listed twice; each task now appears once.
CMTEB_TASK_LIST = [
    # classification
    'TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
    'AmazonReviewsClassification', 'MassiveIntentClassification', 'MassiveScenarioClassification',
    # clustering
    'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
    # pair classification
    'Ocnli', 'Cmnli',
    # reranking
    'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
    # retrieval
    'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval',
    'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
    # STS
    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22',
]
59
+
60
# English MTEB task names, one list per task family.
TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification", "AmazonPolarityClassification",
    "AmazonReviewsClassification", "Banking77Classification",
    "EmotionClassification", "ImdbClassification",
    "MassiveIntentClassification", "MassiveScenarioClassification",
    "MTOPDomainClassification", "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P", "ArxivClusteringS2S",
    "BiorxivClusteringP2P", "BiorxivClusteringS2S",
    "MedrxivClusteringP2P", "MedrxivClusteringS2S",
    "RedditClustering", "RedditClusteringP2P",
    "StackExchangeClustering", "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions", "TwitterSemEval2015", "TwitterURLCorpus",
]

TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions", "MindSmallReranking",
    "SciDocsRR", "StackOverflowDupQuestions",
]

TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "CQADupstackAndroidRetrieval", "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval", "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval", "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval", "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval", "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval", "CQADupstackWordpressRetrieval",
    "DBPedia", "FEVER", "FiQA2018", "NFCorpus", "NQ",
    "QuoraRetrieval", "SCIDOCS", "SciFact", "Touche2020", "TRECCOVID",
    "ClimateFEVER", "HotpotQA", "MSMARCO",
]

# "SummEval" lives in the STS list here even though it is a summarization task.
TASK_LIST_STS = [
    "BIOSSES", "SICK-R",
    "STS12", "STS13", "STS14", "STS15", "STS16", "STS17", "STS22",
    "STSBenchmark", "SummEval",
]

# Every English task, concatenated family by family.
MTEB_TASK_LIST = [
    *TASK_LIST_CLASSIFICATION,
    *TASK_LIST_CLUSTERING,
    *TASK_LIST_PAIR_CLASSIFICATION,
    *TASK_LIST_RERANKING,
    *TASK_LIST_STS,
    *TASK_LIST_RETRIEVAL,
]
153
+
154
+
155
def get_task_type_en(task_name: str):
    """Return the task-type label for an English MTEB task name.

    Args:
        task_name: Name of the task, e.g. ``"NQ"`` or ``"STS12"``.

    Returns:
        One of ``"Summarization"``, ``"Classification"``, ``"Clustering"``,
        ``"PairClassification"``, ``"Reranking"``, ``"STS"`` or
        ``"Retrieval"``.

    Raises:
        ValueError: If the name is in none of the known task lists.
    """
    # SummEval is also a member of TASK_LIST_STS, so it has to be handled
    # before the generic list lookup to be reported as Summarization.
    if task_name == "SummEval":
        return "Summarization"

    groups = (
        ("Classification", TASK_LIST_CLASSIFICATION),
        ("Clustering", TASK_LIST_CLUSTERING),
        ("PairClassification", TASK_LIST_PAIR_CLASSIFICATION),
        ("Reranking", TASK_LIST_RERANKING),
        ("STS", TASK_LIST_STS),
        ("Retrieval", TASK_LIST_RETRIEVAL),
    )
    for label, members in groups:
        if task_name in members:
            return label
    raise ValueError(f"unknown task name:{task_name}")
171
+
172
+
173
def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
    """Return the natural-language instruction for a task.

    STS, Summarization and BitextMining use one fixed instruction per type;
    every other supported type looks the instruction up by task name.

    Args:
        task_name: Task name. For retrieval, lower-cased names and a few
            BEIR-style aliases (e.g. ``"trec-covid"``) are accepted too.
        task_type: Task-type label such as ``"Classification"``.

    Returns:
        The instruction text for the task.

    Raises:
        KeyError: If the type is supported but the task name is unknown.
        ValueError: If the task type is not supported.
    """
    if task_type == 'STS':
        return "Retrieve semantically similar text."
    if task_type == 'Summarization':
        return "Given a news summary, retrieve other semantically similar summaries"
    if task_type == 'BitextMining':
        return "Retrieve parallel sentences."

    if task_type == 'Classification':
        classification_instructions: Dict[str, str] = {
            'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual',
            'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment',
            'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category',
            'Banking77Classification': 'Given a online banking query, find the corresponding intents',
            'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise',
            'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset',
            'MassiveIntentClassification': 'Given a user utterance as query, find the user intents',
            'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios',
            'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation',
            'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation',
            'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic',
            'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral',
            # C-MTEB eval instructions
            'TNews': 'Classify the fine-grained category of the given news title',
            'IFlyTek': 'Given an App description text, find the appropriate fine-grained category',
            'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative',
            'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative',
            'OnlineShopping': 'Classify the customer review for online shopping into positive or negative',
            'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative',
        }
        return classification_instructions[task_name]

    if task_type == 'Clustering':
        clustering_instructions: Dict[str, str] = {
            'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts',
            'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles',
            'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts',
            'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles',
            'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts',
            'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles',
            'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles',
            'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts',
            'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles',
            'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs',
            'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles',
            # C-MTEB eval instructions
            'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles',
            'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts',
            'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles',
            'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents',
        }
        return clustering_instructions[task_name]

    if task_type in ('Reranking', 'PairClassification'):
        pairwise_instructions: Dict[str, str] = {
            'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum',
            'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history',
            'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers',
            'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum',
            'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum',
            'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet',
            'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet',
            # C-MTEB eval instructions
            'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question',
            'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question',
            'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question',
            'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question',
            'Ocnli': 'Retrieve semantically similar text.',
            'Cmnli': 'Retrieve semantically similar text.',
        }
        return pairwise_instructions[task_name]

    if task_type == 'Retrieval':
        # All CQADupstack sub-tasks share a single instruction.
        if task_name.lower().startswith('cqadupstack'):
            return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question'

        retrieval_instructions: Dict[str, str] = {
            'ArguAna': 'Given a claim, find documents that refute the claim',
            'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim',
            'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia',
            'FEVER': 'Given a claim, retrieve documents that support or refute the claim',
            'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question',
            'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question',
            'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.',
            'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question',
            'NQ': 'Given a question, retrieve Wikipedia passages that answer the question',
            'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question',
            'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
            'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim',
            'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question',
            'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
            # C-MTEB eval instructions
            'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
            'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query',
            'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
            'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question',
            'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question',
            'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products',
            'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question',
            'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos',
        }

        # Lower-cased keys match some BEIR dataset names directly.
        lowercase_keys = {name.lower(): text for name, text in retrieval_instructions.items()}
        # BEIR names that differ from the MTEB name by more than letter case.
        beir_aliases = {
            'trec-covid': 'TRECCOVID',
            'climate-fever': 'ClimateFEVER',
            'dbpedia-entity': 'DBPedia',
            'webis-touche2020': 'Touche2020',
            'fiqa': 'FiQA2018',
            'quora': 'QuoraRetrieval',
        }
        alias_keys = {alias: retrieval_instructions[canonical] for alias, canonical in beir_aliases.items()}

        lookup = {
            **retrieval_instructions,
            **lowercase_keys,
            **alias_keys,
            # for miracl evaluation
            'miracl': 'Given a question, retrieve Wikipedia passages that answer the question',
        }
        return lookup[task_name]

    raise ValueError(f"No instruction config for task {task_name} with type {task_type}")
293
+
294
+
295
def get_detailed_instruct(task_description: str) -> str:
    """Wrap a task description in the ``Instruct: ...\\nQuery: `` prompt prefix.

    Args:
        task_description: Instruction text for the task.

    Returns:
        The prompt prefix, or an empty string when the description is empty.
    """
    if task_description:
        return f'Instruct: {task_description}\nQuery: '
    return ''
300
+
301
+
302
if __name__ == "__main__":
    # Quick sanity check: print how many English MTEB tasks are configured.
    task_count = len(MTEB_TASK_LIST)
    print(task_count)
scripts/evaluate_en_mteb/run_evaluate_mteb.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Evaluate the Jasper model on the "MTEB(eng, classic)" benchmark."""
import os
# Comment out the following line if the Hugging Face mirror is not needed.
# It is set before importing mteb so the endpoint is in place for the imports below.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import mteb
from model_for_evaluate import MTEB_Sentence_Transformer

if __name__ == "__main__":
    model_name = "valid_jasper"
    model = MTEB_Sentence_Transformer(
        model_path_or_name="infgrad/jasper_en_vision_language_v1",
        lang="en",
        batch_size=27,
        max_length=400,
    )
    try:
        tasks = list(mteb.get_benchmark("MTEB(eng, classic)"))
        evaluation = mteb.MTEB(tasks=tasks)
        evaluation.run(
            model,
            output_folder=f"./en_results/{model_name}",
            overwrite_results=False,
            verbosity=3
        )
    finally:
        # Always shut the worker pool down, even when evaluation fails,
        # so no encoder processes are left running.
        model.model.stop_multi_process_pool(model.pool)