File size: 3,057 Bytes
0b16e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
f8ba3aa
0b16e3b
 
 
 
 
 
 
 
 
 
 
 
7db4440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b16e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3669fe2
 
0b16e3b
3669fe2
 
0b16e3b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
from torch import nn
import csv
from torch.utils.data import DataLoader, Dataset
import torch
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
import logging
import json
import random
import gzip

# Name of the pretrained model handed to SentenceTransformer below, and the
# directory the trained model is written to ("." is the working directory).
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_save_path = "."

# Longest input sequence the model is configured to accept.
max_seq_length = 128

# Optimisation settings: batch size for the training DataLoader, plus the
# epoch count, warm-up step count and learning rate passed to model.fit.
train_batch_size = 100
num_epochs = 1
warmup_steps = 1000
lr = 2e-5

class ESCIDataset(Dataset):
    """Triplet dataset built from a gzipped JSON-lines file.

    Every non-blank line of the file is parsed as one JSON object. The
    object's ``'query'`` value is combined with the ``'title'`` of each entry
    listed under ``'e'``, ``'s'`` and ``'c'`` (used as positives) and with
    the ``'title'`` of each entry listed under ``'i'`` (used as negatives).
    Each (query, positive, negative) combination becomes one ``InputExample``
    appended to ``self.queries``.

    NOTE(review): the keys presumably stand for the ESCI labels exact /
    substitute / complement / irrelevant -- confirm against the data producer.
    """

    # Record keys whose entries are used as positives, in emission order.
    POSITIVE_KEYS = ('e', 's', 'c')
    # Record key whose entries are used as negatives.
    NEGATIVE_KEY = 'i'

    def __init__(self, input):
        """Read every triplet from the gzip-compressed file at *input*.

        The parameter name shadows the builtin ``input``; it is kept because
        callers in this file pass it by keyword.

        Raises:
            OSError: if the file cannot be opened or is not valid gzip.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a record that yields triplets lacks ``'query'`` or
                one of its entries lacks ``'title'``.
        """
        super().__init__()
        self.queries = []
        with gzip.open(input) as jsonfile:
            # Stream line by line instead of readlines(), so the whole
            # decompressed file is never held in memory at once.
            for line in jsonfile:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                record = json.loads(line)
                self.queries.extend(
                    InputExample(texts=list(triplet))
                    for triplet in self._triplets(record)
                )

    @classmethod
    def _triplets(cls, record):
        """Yield ``(query, positive_title, negative_title)`` for one record.

        Positives are visited key by key in ``POSITIVE_KEYS`` order and each
        one is paired with every negative in file order. A label key that is
        absent from *record* is treated as an empty list, so such a record
        simply contributes fewer (possibly zero) triplets.
        """
        for key in cls.POSITIVE_KEYS:
            for positive_item in record.get(key, ()):
                positive = positive_item['title']
                for negative_item in record.get(cls.NEGATIVE_KEY, ()):
                    yield (record['query'], positive, negative_item['title'])

    def __getitem__(self, item):
        """Return the example stored at position *item*."""
        return self.queries[item]

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.queries)


# Load the pretrained model named by model_name and set its maximum input
# sequence length from the module-level setting.
model = SentenceTransformer(model_name_or_path=model_name)
model.max_seq_length = max_seq_length


# Triplet datasets read from the gzipped JSON-lines files in the working
# directory: one for training, one for evaluation.
train_dataset = ESCIDataset('train-small.json.gz')
eval_dataset = ESCIDataset('test-small.json.gz')

# Shuffled mini-batches of training triplets and the loss they are trained
# with.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
train_loss = losses.MultipleNegativesRankingLoss(model)

# Group the evaluation triplets by query into the per-query structure given
# to RerankingEvaluator: {'query': str, 'positive': [...], 'negative': [...]}.
#
# ESCIDataset emits one example per (positive, negative) combination, so each
# positive title appears once per negative and vice versa. Appending every
# occurrence would list the same document many times and the evaluator would
# rank each copy as a separate document. Titles are therefore collected as
# dict keys, which keeps each one once per query in first-seen order.
_grouped = {}
for _example in eval_dataset.queries:
    _query_text, _positive_title, _negative_title = _example.texts
    _bucket = _grouped.setdefault(_query_text, {'positive': {}, 'negative': {}})
    _bucket['positive'][_positive_title] = None
    _bucket['negative'][_negative_title] = None

# Mapping of query text -> sample dict, same shape as before but de-duplicated.
samples = {
    _query_text: {
        'query': _query_text,
        'positive': list(_bucket['positive']),
        'negative': list(_bucket['negative']),
    }
    for _query_text, _bucket in _grouped.items()
}

# The evaluator documents `samples` as a list of sample dicts, so hand it the
# values rather than the query-keyed mapping.
evaluator = RerankingEvaluator(samples=list(samples.values()), name='esci')

# Fine-tune the model on the triplet objective. The evaluator and the output
# directory are both handed to fit().
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={'lr': lr},
    use_amp=True,
    output_path=model_save_path,
    # Currently disabled options:
    #   checkpoint_path=model_save_path,
    #   checkpoint_save_steps=len(train_dataloader),
    #   evaluation_steps=1000,
)

# Write the model as it stands after training to the same directory.
model.save(model_save_path)