|
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder |
|
from torch import nn |
|
import csv |
|
from torch.utils.data import DataLoader, Dataset |
|
import torch |
|
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator |
|
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator |
|
import logging |
|
import json |
|
import random |
|
import gzip |
|
|
|
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------

# Pre-trained checkpoint that is loaded and then fine-tuned below.
model_name = 'sentence-transformers/all-MiniLM-L6-v2'

# Batch size of the training DataLoader, and the value assigned to the
# model's max_seq_length.
train_batch_size = 128
max_seq_length = 128

# Optimisation schedule handed to model.fit().
num_epochs = 1
# NOTE(review): on a small training file, 1000 warmup steps may exceed the
# total number of optimizer steps -- consider scaling with dataset size.
warmup_steps = 1_000
lr = 2e-5

# Directory used for intermediate checkpoints and for the final model.
model_save_path = '.'
|
|
|
class ESCIDataset(Dataset):
    """Map-style dataset of (query, positive title, negative title) triplets.

    The input is a gzip-compressed file with one JSON object per line. Each
    object must provide a ``'query'`` string plus ``'e'`` and ``'i'`` lists
    whose entries carry a ``'title'`` key (presumably the ESCI "exact" and
    "irrelevant" judgements -- confirm against the data export). For every
    query with at least one ``'e'`` entry and one ``'i'`` entry,
    ``samples_per_query`` ``InputExample`` triplets are built. The positive
    title is drawn from ``'e'`` and the negative from ``'i'``, with
    replacement, so identical triplets can occur. Queries missing either side
    are skipped, and so are blank lines.

    Args:
        input: Path of the ``.json.gz`` file. The name shadows the builtin
            but is kept because callers pass it by keyword.
        samples_per_query: Triplets drawn per eligible query. The default of
            9 reproduces the original ``range(1, 10)`` loop.
        rng: Optional ``random.Random`` instance used for sampling, e.g. for
            reproducible runs. When omitted, the module-level ``random``
            functions are used, exactly as before.

    Raises:
        KeyError: if a record lacks a key the loader reads.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """

    def __init__(self, input, samples_per_query=9, rng=None):
        choice = random.choice if rng is None else rng.choice
        self.queries = []

        with gzip.open(input) as jsonfile:
            # Iterate lazily rather than readlines(), which would hold the
            # whole decompressed file in memory at once.
            for line in jsonfile:
                if not line.strip():
                    continue  # tolerate blank lines instead of crashing json.loads
                record = json.loads(line)

                # Eligibility does not change between samples, so check it once.
                # The short-circuit order ('e' before 'i') matches the original.
                if not record['e'] or not record['i']:
                    continue

                exact = record['e']
                irrelevant = record['i']
                query_text = record['query']
                for _ in range(samples_per_query):
                    # Same call order as before (positive, then negative), so a
                    # given random state yields the same triplets.
                    positive = choice(exact)['title']
                    negative = choice(irrelevant)['title']
                    self.queries.append(
                        InputExample(texts=[query_text, positive, negative])
                    )

    def __getitem__(self, item):
        """Return the ``InputExample`` stored at index ``item``."""
        return self.queries[item]

    def __len__(self):
        """Return the number of generated triplets."""
        return len(self.queries)
|
|
|
|
|
# Load the pre-trained bi-encoder and set its max_seq_length from the config.
model = SentenceTransformer(model_name)
model.max_seq_length = max_seq_length

# Triplet datasets: one for training, one to build the reranking evaluation.
train_dataset = ESCIDataset('train-small.json.gz')
eval_dataset = ESCIDataset('test-small.json.gz')

# NOTE(review): ESCIDataset emits several triplets per query, so a shuffled
# batch can hold the same query more than once. MultipleNegativesRankingLoss
# then treats the other copy's positive as a negative -- consider a
# duplicate-free batching strategy.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
train_loss = losses.MultipleNegativesRankingLoss(model)
|
|
|
# Group the evaluation triplets into one reranking sample per query string,
# shaped {'query': str, 'positive': [titles], 'negative': [titles]}, which is
# what RerankingEvaluator consumes. (It also accepts a dict of such samples
# and uses its values.)
#
# ESCIDataset draws several triplets per query with replacement, so the same
# title can appear repeatedly for one query. Each title is kept only once,
# in first-seen order, so duplicate documents do not skew MAP/MRR.
samples = {}

for query in eval_dataset.queries:
    qstr, pos_title, neg_title = query.texts[0], query.texts[1], query.texts[2]
    sample = samples.setdefault(
        qstr, {'query': qstr, 'positive': [], 'negative': []}
    )
    # The lists hold at most a handful of titles per query, so a linear
    # membership test is cheap and keeps insertion order.
    if pos_title not in sample['positive']:
        sample['positive'].append(pos_title)
    if neg_title not in sample['negative']:
        sample['negative'].append(neg_title)

evaluator = RerankingEvaluator(samples=samples, name='esci')
|
|
|
|
|
|
|
# Fine-tune with MultipleNegativesRankingLoss. The reranking evaluator is now
# passed in so it actually runs; previously it was built but never used. It
# runs at the end of each epoch unless evaluation_steps is set.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    # Mixed precision needs CUDA. On a CPU-only machine it is at best a no-op,
    # and some sentence-transformers/transformers versions reject fp16 outright.
    use_amp=torch.cuda.is_available(),
    checkpoint_path=model_save_path,
    # Checkpoint every len(train_dataloader) steps, i.e. once per epoch.
    checkpoint_save_steps=len(train_dataloader),
    optimizer_params={'lr': lr},
    output_path=model_save_path,
)

# With an evaluator, fit() writes to output_path only when the evaluation
# score improves. Save the final weights explicitly as well.
model.save(model_save_path)
|
|