"""Fine-tune a bi-encoder on (query, product title) pairs with CosineSimilarityLoss.

Each line of the gzipped JSON-lines input is one query record whose 'e', 's',
'c' and 'i' lists hold product dicts carrying a 'title'. Every (query, title)
pair becomes an InputExample whose target similarity is taken from LABELS.
"""
import csv
import gzip
import json
import logging
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator

model_name = 'sentence-transformers/all-MiniLM-L6-v2'
train_batch_size = 100
max_seq_length = 128
num_epochs = 1
warmup_steps = 1000
model_save_path = 'cos-exp'
lr = 2e-5

# Set to True to run a RerankingEvaluator on the test split during training.
# Off by default, matching the previous script where the evaluator was
# commented out.
use_evaluator = False
evaluation_steps = 1000

# Target similarity label for each bucket key in an input record.
# Dict insertion order fixes the per-query example order: e, s, c, i.
LABELS = {'e': 1.0, 's': 0.1, 'c': 0.01, 'i': 0.0}


class ESCIDataset(Dataset):
    """In-memory dataset of (query, product title) InputExamples.

    Args:
        input: path to a gzip-compressed JSON-lines file. Each non-blank line
            is a JSON object with a 'query' string and optional 'e'/'s'/'c'/'i'
            lists of dicts that each carry a 'title'.
    """

    def __init__(self, input):
        self.queries = []
        with gzip.open(input) as jsonfile:
            # Iterate lazily instead of readlines() so the whole decompressed
            # file is never held in memory at once. Lines stay as bytes, which
            # json.loads decodes exactly as before.
            for line in jsonfile:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                record = json.loads(line)
                query = record['query']
                for key, label in LABELS.items():
                    # A missing bucket counts as empty instead of raising KeyError.
                    for product in record.get(key, []):
                        self.queries.append(
                            InputExample(texts=[query, product['title']], label=label))

    def __getitem__(self, item):
        return self.queries[item]

    def __len__(self):
        return len(self.queries)


def build_reranking_samples(dataset, positive_threshold=1.0):
    """Group an ESCIDataset's examples by query into RerankingEvaluator samples.

    Titles whose label is >= positive_threshold go to 'positive'; all others
    go to 'negative'.

    Returns:
        list of {'query': str, 'positive': [str], 'negative': [str]} dicts,
        one per distinct query, in first-seen order.
    """
    grouped = {}
    for example in dataset.queries:
        # Every example holds exactly [query, title]; there is no third text.
        query, title = example.texts
        sample = grouped.setdefault(query, {'query': query, 'positive': [], 'negative': []})
        side = 'positive' if example.label >= positive_threshold else 'negative'
        sample[side].append(title)
    return list(grouped.values())


def main():
    """Train the bi-encoder on the training split and save it to model_save_path."""
    # LoggingHandler routes log output through tqdm so it doesn't break the progress bars.
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])

    model = SentenceTransformer(model_name, device='cpu')
    model.max_seq_length = max_seq_length

    train_dataset = ESCIDataset(input='train-small.json.gz')
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.CosineSimilarityLoss(model=model)

    fit_kwargs = {}
    if use_evaluator:
        # Only load the test split when it is actually used.
        eval_dataset = ESCIDataset(input='test-small.json.gz')
        fit_kwargs['evaluator'] = RerankingEvaluator(
            samples=build_reranking_samples(eval_dataset), name='esci')
        fit_kwargs['evaluation_steps'] = evaluation_steps

    model.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=num_epochs,
              warmup_steps=warmup_steps,
              # NOTE(review): AMP is requested while the model is created on
              # 'cpu'; confirm mixed precision actually takes effect here.
              use_amp=True,
              optimizer_params={'lr': lr},
              output_path=model_save_path,
              **fit_kwargs)

    # Explicit save of the final weights. NOTE(review): fit() may already have
    # written output_path; with an evaluator enabled this overwrites any
    # best-checkpoint fit() saved there — confirm which model you want kept.
    model.save(model_save_path)


if __name__ == '__main__':
    main()