File size: 3,375 Bytes
f80e0e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
from torch import nn
import csv
from torch.utils.data import DataLoader, Dataset
import torch
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
import logging
import json
import random
import gzip

# --- Training configuration -------------------------------------------------

# Pretrained cross-encoder checkpoint that is fine-tuned below.
model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'

# Optimisation schedule (learning rate is passed via optimizer_params).
num_epochs = 1
train_batch_size = 64
warmup_steps = 1000
lr = 2e-5

# Intended maximum token length of a (query, title) pair.
max_seq_length = 128

# Directory used both as fit(output_path=...) and for model.save(...).
model_save_path = '.'

class ESCIDataset(Dataset):
    """Labelled (query, title) training pairs read from a gzipped JSON-lines file.

    Every non-blank line must be a JSON object holding a ``'query'`` string and,
    for each group key in ``labels``, a list of product dicts that each carry a
    ``'title'``. One ``InputExample(texts=[query, title], label=...)`` is built
    per product, labelled by the group it is listed under. Within a line the
    groups are emitted in ``labels`` order.

    Args:
        input: Path of the gzip-compressed JSON-lines file.
        labels: Mapping (or iterable of ``(key, label)`` pairs) from group key
            to target score. ``None`` uses ``DEFAULT_LABELS``.

    Raises:
        KeyError: if a line lacks ``'query'`` or one of the group keys, or a
            product lacks ``'title'``.
    """

    # Default group key -> target score, in emission order. Stored as a tuple
    # so the shared class-level default cannot be mutated by callers.
    # NOTE(review): the keys look like the ESCI classes (Exact, Substitute,
    # Complement, Irrelevant) — confirm against the data preparation script.
    DEFAULT_LABELS = (('e', 1.0), ('s', 0.1), ('c', 0.01), ('i', 0.0))

    def __init__(self, input, labels=None):
        self.queries = []
        # Not populated or read anywhere in this file; kept so existing
        # attribute access does not break.
        self.posneg = []
        label_pairs = list(dict(self.DEFAULT_LABELS if labels is None else labels).items())
        with gzip.open(input) as jsonfile:
            # Stream lines instead of readlines() so the whole decompressed
            # file is not held in memory while the examples are built.
            for line in jsonfile:
                if not line.strip():
                    # Skip blank lines rather than failing inside json.loads.
                    continue
                query = json.loads(line)
                query_text = query['query']
                for group, label in label_pairs:
                    for doc in query[group]:
                        self.queries.append(
                            InputExample(texts=[query_text, doc['title']], label=label)
                        )

    def __getitem__(self, item):
        return self.queries[item]

    def __len__(self):
        return len(self.queries)

class ESCIEvalDataset(Dataset):
    """(query, positive title, negative title) triplets from a gzipped JSON-lines file.

    For every line whose ``'e'`` and ``'i'`` product lists are both non-empty,
    one ``InputExample(texts=[query, e_title, i_title])`` is produced for each
    combination of an ``'e'`` product and an ``'i'`` product, giving
    ``len(e) * len(i)`` examples per line. Lines where either list is empty
    contribute nothing.

    NOTE(review): ``'e'``/``'i'`` presumably are the ESCI Exact / Irrelevant
    groups — confirm against the data preparation script.
    """

    def __init__(self, input):
        self.queries = []
        with gzip.open(input) as stream:
            for raw in stream:
                record = json.loads(raw)
                # 'i' is only looked up once 'e' is known to be non-empty.
                if len(record['e']) == 0 or len(record['i']) == 0:
                    continue
                for pos in record['e']:
                    pos_title = pos['title']
                    self.queries.extend(
                        InputExample(texts=[record['query'], pos_title, neg['title']])
                        for neg in record['i']
                    )

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, item):
        return self.queries[item]

# A single regression output scores each (query, title) pair.
# CrossEncoder takes its truncation length as the ``max_length`` constructor
# argument. Assigning ``model.max_seq_length`` (a SentenceTransformer
# attribute) is never read by CrossEncoder, so pairs were not truncated to
# max_seq_length tokens.
model = CrossEncoder(model_name, num_labels=1, max_length=max_seq_length)


# Labelled training pairs (reshuffled by the DataLoader) and the held-out
# evaluation triplets.
train_dataset = ESCIDataset('train-small.json.gz')
eval_dataset = ESCIEvalDataset('test-small.json.gz')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)

# Regroup the evaluation triplets into the per-query format passed to
# CERerankingEvaluator: {'query': str, 'positive': [titles], 'negative': [titles]}.
# eval_dataset.queries holds one (query, positive, negative) triplet for every
# positive/negative combination. Appending blindly would therefore repeat each
# positive once per negative (and each negative once per positive), inflating
# the candidate lists and skewing the ranking metric. Keep each title once per
# query, in first-seen order. Identical titles would receive identical scores
# anyway, so collapsing them loses nothing.
samples = {}
for triplet in eval_dataset.queries:
    qstr, positive, negative = triplet.texts
    sample = samples.setdefault(qstr, {'query': qstr, 'positive': [], 'negative': []})
    if positive not in sample['positive']:
        sample['positive'].append(positive)
    if negative not in sample['negative']:
        sample['negative'].append(negative)

evaluator = CERerankingEvaluator(samples=samples, name='esci')

# --- Fine-tune ----------------------------------------------------------------
# evaluation_steps is left at the library default (an explicit
# evaluation_steps=1000 used to be commented out here).
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={'lr': lr},
    use_amp=True,
    output_path=model_save_path,
)

# --- Persist the final weights ------------------------------------------------
# NOTE(review): fit() is given the same output_path. If fit() saves its
# best-scoring checkpoint there, this call overwrites it with the last-epoch
# weights. The two coincide while num_epochs == 1 — confirm against the
# installed sentence-transformers version before raising num_epochs.
model.save(model_save_path)