# File size: 3,057 Bytes
# Viewer metadata (appears to be per-line revision hashes):
#   0b16e3b f8ba3aa 0b16e3b 7db4440 0b16e3b 3669fe2 0b16e3b 3669fe2 0b16e3b
# (The viewer's line-number gutter, 1-91, carried no content and is omitted.)
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
from torch import nn
import csv
from torch.utils.data import DataLoader, Dataset
import torch
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
import logging
import json
import random
import gzip
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
train_batch_size = 100
max_seq_length = 128
num_epochs = 1
warmup_steps = 1000
model_save_path = '.'
lr = 2e-5
class ESCIDataset(Dataset):
def __init__(self, input):
self.queries = []
with gzip.open(input) as jsonfile:
for line in jsonfile.readlines():
query = json.loads(line)
for p in query['e']:
positive = p['title']
for n in query['i']:
negative = n['title']
self.queries.append(InputExample(texts=[query['query'], positive, negative]))
for p in query['s']:
positive = p['title']
for n in query['i']:
negative = n['title']
self.queries.append(InputExample(texts=[query['query'], positive, negative]))
for p in query['c']:
positive = p['title']
for n in query['i']:
negative = n['title']
self.queries.append(InputExample(texts=[query['query'], positive, negative]))
def __getitem__(self, item):
return self.queries[item]
def __len__(self):
return len(self.queries)
model = SentenceTransformer(model_name)
model.max_seq_length = max_seq_length
train_dataset = ESCIDataset(input='train-small.json.gz')
eval_dataset = ESCIDataset(input='test-small.json.gz')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MultipleNegativesRankingLoss(model=model)
samples = {}
for query in eval_dataset.queries:
qstr = query.texts[0]
sample = samples.get(qstr, {'query': qstr})
positive = sample.get('positive', [])
positive.append(query.texts[1])
sample['positive'] = positive
negative = sample.get('negative', [])
negative.append(query.texts[2])
sample['negative'] = negative
samples[qstr] = sample
evaluator = RerankingEvaluator(samples=samples,name='esci')
# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
epochs=num_epochs,
warmup_steps=warmup_steps,
use_amp=True,
# checkpoint_path=model_save_path,
# checkpoint_save_steps=len(train_dataloader),
optimizer_params = {'lr': lr},
evaluator=evaluator,
# evaluation_steps=1000,
output_path=model_save_path
)
# Save the model
model.save(model_save_path)
|