"""Fine-tune a cross-encoder on query/title pairs and evaluate it by reranking.

The script reads gzip-compressed JSON-lines files in which every line is an
object with a ``query`` string and four lists of documents under the keys
``e``, ``s``, ``c`` and ``i``; each document carries a ``title``.
NOTE(review): the keys presumably stand for the ESCI relevance grades
(exact / substitute / complement / irrelevant) -- confirm against the data
producer.
"""

import csv
import gzip
import json
import logging
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from sentence_transformers import (
    CrossEncoder,
    InputExample,
    LoggingHandler,
    SentenceTransformer,
    evaluation,
    losses,
    models,
    util,
)
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    RerankingEvaluator,
    SentenceEvaluator,
    SimilarityFunction,
)

# Training configuration.
model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
train_batch_size = 64
max_seq_length = 128
num_epochs = 1
warmup_steps = 1000
model_save_path = '.'
lr = 2e-5


class ESCIDataset(Dataset):
    """Training set of (query, title) pairs with a graded relevance label.

    Every document listed under one of the keys in ``LABEL_SCORES`` becomes
    one ``InputExample`` whose label is the score mapped to that key.
    """

    # Document-list key -> training label. Insertion order is significant:
    # examples for one query are emitted in this key order.
    LABEL_SCORES = {'e': 1.0, 's': 0.1, 'c': 0.01, 'i': 0.0}

    def __init__(self, input):
        """Load all examples from the gzip-compressed JSON-lines file *input*.

        The parameter name shadows the ``input`` builtin but is kept because
        callers pass it by keyword.

        Raises:
            KeyError: if a line lacks ``query`` or one of the label keys, or
                a document lacks ``title``.
        """
        self.queries = []
        self.posneg = []  # never filled in this file; kept for compatibility
        with gzip.open(input) as jsonfile:
            # Iterate the file object directly so lines are parsed as they
            # are decompressed instead of materialising the whole file first.
            # The file is opened in binary mode; json.loads accepts bytes.
            for line in jsonfile:
                query = json.loads(line)
                for key, score in self.LABEL_SCORES.items():
                    for doc in query[key]:
                        self.queries.append(
                            InputExample(
                                texts=[query['query'], doc['title']],
                                label=score,
                            )
                        )

    def __getitem__(self, item):
        """Return the example at index *item*."""
        return self.queries[item]

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.queries)


class ESCIEvalDataset(Dataset):
    """Evaluation set of (query, positive title, negative title) triplets.

    For every query, each document under ``e`` is paired with each document
    under ``i``; queries missing either kind contribute nothing.
    """

    def __init__(self, input):
        """Load all triplets from the gzip-compressed JSON-lines file *input*.

        The parameter name shadows the ``input`` builtin but is kept because
        callers pass it by keyword.
        """
        self.queries = []
        with gzip.open(input) as jsonfile:
            for line in jsonfile:
                query = json.loads(line)
                # The nested iteration yields nothing when either list is
                # empty, so no explicit emptiness check is needed.
                self.queries.extend(
                    InputExample(
                        texts=[query['query'], pos['title'], neg['title']]
                    )
                    for pos in query['e']
                    for neg in query['i']
                )

    def __getitem__(self, item):
        """Return the triplet at index *item*."""
        return self.queries[item]

    def __len__(self):
        """Return the number of loaded triplets."""
        return len(self.queries)


model = CrossEncoder(model_name, num_labels=1)
model.max_seq_length = max_seq_length

train_dataset = ESCIDataset(input='train-small.json.gz')
eval_dataset = ESCIEvalDataset(input='test-small.json.gz')

train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=train_batch_size
)

# Regroup the triplets into one reranking sample per query text.
# The triplets are a positive x negative cross product, so every title shows
# up in many triplets; collecting them in insertion-ordered dicts keeps each
# title once per query. Appending them blindly would hand the evaluator
# duplicated candidates, which skews the ranking metric and wastes inference.
titles_by_query = {}
for example in eval_dataset.queries:
    query_text, positive_title, negative_title = example.texts
    positives, negatives = titles_by_query.setdefault(query_text, ({}, {}))
    positives[positive_title] = None
    negatives[negative_title] = None

samples = {
    query_text: {
        'query': query_text,
        'positive': list(positives),
        'negative': list(negatives),
    }
    for query_text, (positives, negatives) in titles_by_query.items()
}

evaluator = CERerankingEvaluator(samples=samples, name='esci')

# Train the model
model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    use_amp=True,
    optimizer_params={'lr': lr},
    evaluator=evaluator,
    # evaluation_steps=1000,
    output_path=model_save_path,
)

# Save the model
model.save(model_save_path)