"""Fine-tune a SentenceTransformer bi-encoder on ESCI-style query/product triplets.

The script reads gzip-compressed JSON-lines files, builds
(query, positive title, negative title) training triplets, trains with
``MultipleNegativesRankingLoss`` and saves the resulting model.
"""
import csv
import gzip
import json
import logging
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator

# Training hyperparameters.
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
train_batch_size = 128
max_seq_length = 128
num_epochs = 1
warmup_steps = 1000
model_save_path = '.'
lr = 2e-5


class ESCIDataset(Dataset):
    """In-memory dataset of (query, positive title, negative title) triplets.

    Each line of the gzip-compressed input file is parsed as a JSON object
    that must provide:

    * ``query`` -- the query text,
    * ``e`` -- a list of objects with a ``title`` field, used as positives,
    * ``i`` -- a list of objects with a ``title`` field, used as negatives.

    NOTE(review): ``e`` / ``i`` presumably stand for the ESCI "exact" and
    "irrelevant" labels -- confirm against the data preparation step.

    Queries with an empty ``e`` or ``i`` list contribute no examples.
    """

    # Number of randomly drawn triplets per query (the original loop was
    # ``range(1, 10)``, i.e. nine draws, not ten).
    TRIPLETS_PER_QUERY = 9

    def __init__(self, input):
        """Load *input* and sample the training triplets.

        Args:
            input: Path of the gzip-compressed JSON-lines file to read.
        """
        super().__init__()
        self.queries = []
        with gzip.open(input) as jsonfile:
            # Stream the file line by line instead of materialising every
            # line with readlines(); json.loads accepts the raw bytes lines.
            for line in jsonfile:
                query = json.loads(line)
                # A triplet needs at least one positive and one negative.
                if not query['e'] or not query['i']:
                    continue
                for _ in range(self.TRIPLETS_PER_QUERY):
                    # Draw the positive first, then the negative, with
                    # replacement across draws.
                    positive = random.choice(query['e'])['title']
                    negative = random.choice(query['i'])['title']
                    self.queries.append(
                        InputExample(texts=[query['query'], positive, negative])
                    )

    def __getitem__(self, item):
        """Return the ``InputExample`` stored at index *item*."""
        return self.queries[item]

    def __len__(self):
        """Return the number of sampled triplets."""
        return len(self.queries)


model = SentenceTransformer(model_name)
model.max_seq_length = max_seq_length

train_dataset = ESCIDataset(input='train-small.json.gz')
eval_dataset = ESCIDataset(input='test-small.json.gz')

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# Group the evaluation triplets by query text: one entry per distinct query,
# holding every sampled positive and negative title for that query.
samples = {}
for example in eval_dataset.queries:
    query_text, positive_title, negative_title = example.texts
    sample = samples.setdefault(
        query_text, {'query': query_text, 'positive': [], 'negative': []}
    )
    sample['positive'].append(positive_title)
    sample['negative'].append(negative_title)

# NOTE(review): ``samples`` is a dict keyed by query text; confirm that the
# installed RerankingEvaluator version accepts a dict as well as a list.
evaluator = RerankingEvaluator(samples=samples, name='esci')

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    use_amp=True,
    checkpoint_path=model_save_path,
    # len(train_dataloader) is the number of batches per epoch.
    checkpoint_save_steps=len(train_dataloader),
    optimizer_params={'lr': lr},
    # The evaluator built above is intentionally left unwired:
    # evaluator=evaluator,
    # evaluation_steps=300,
    output_path=model_save_path,
)

# Save the model
model.save(model_save_path)