LENS Embeddings
LENS is a model that produces Lexicon-based EmbeddiNgS (LENS) by leveraging large language models. Each dimension of the embedding corresponds to a token cluster in which semantically similar tokens are grouped together. The embeddings have a feature size comparable to that of dense embeddings; LENS-d8000 produces 8000-dimensional representations.
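The toy sketch below illustrates what "one dimension per token cluster" can mean. It assumes that the clusters come from k-means over the LLM's token embedding rows and that the cluster centroids then act as the output projection, so a hidden state is scored against K clusters instead of V individual tokens. The sizes, the random embeddings, and the `kmeans` helper are all hypothetical. See the technical report for the procedure actually used by LENS.

```python
import torch

# Hypothetical toy setup: V "tokens", each with an h-dimensional embedding.
# In a real LLM these rows would come from the model's token embedding matrix.
torch.manual_seed(0)
V, h, K = 100, 16, 8  # vocab size, hidden size, number of clusters (8000 for LENS-d8000)
token_embeddings = torch.randn(V, h)


def kmeans(x: torch.Tensor, k: int, iters: int = 10) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain k-means (hypothetical helper): returns (centroids [k, h], assignments [n])."""
    centroids = x[torch.randperm(x.shape[0])[:k]].clone()
    for _ in range(iters):
        # Assign every token to its nearest centroid (Euclidean distance).
        assign = torch.cdist(x, centroids).argmin(dim=1)
        for j in range(k):
            members = x[assign == j]
            if len(members) > 0:  # keep the previous centroid if a cluster becomes empty
                centroids[j] = members.mean(dim=0)
    return centroids, assign


centroids, assignments = kmeans(token_embeddings, K)

# Using the centroids as the output projection maps a hidden state to K scores,
# one per token cluster, rather than V scores, one per token.
hidden_state = torch.randn(h)
cluster_logits = hidden_state @ centroids.T
print(cluster_logits.shape)  # torch.Size([8])
```

Grouping similar tokens this way keeps each output dimension interpretable as a set of related lexical items while shrinking the output from the full vocabulary to a dense-sized vector.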
The technical report for LENS is available as Enhancing Lexicon-Based Text Embeddings with Large Language Models.
Usage
git clone https://huggingface.co/yibinlei/LENS-d8000
cd LENS-d8000
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer
from bidirectional_mistral import MistralBiForCausalLM
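def get_detailed_instruct(task_instruction: str, query: str) -> str:
    # Prepend the task instruction; the <query> tag marks where the query tokens begin.
    return f'<instruct>{task_instruction}\n<query>{query}'
def pooling_func(vecs: Tensor, pooling_mask: Tensor) -> Tensor:
    # We use max-pooling for LENS: log(1 + ReLU(x)) compresses large activations,
    # masked positions are zeroed, and the maximum is taken over the sequence dimension.
    return torch.max(torch.log(1 + torch.relu(vecs)) * pooling_mask.unsqueeze(-1), dim=1).values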
# Prepare the data
instruction = "Given a web search query, retrieve relevant passages that answer the query."
queries = ["what is rba",
"what is oilskin fabric"]
instructed_queries = [get_detailed_instruct(instruction, query) for query in queries]
docs = ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal.",
"Today's oilskins (or oilies) typically come in two parts, jackets and trousers. Oilskin jackets are generally similar to common rubberized waterproofs."]
# Load the model and tokenizer
model = MistralBiForCausalLM.from_pretrained("yibinlei/LENS-d8000", ignore_mismatched_sizes=True)
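# lm_head.pth holds the replacement output projection as a pickled module. PyTorch >= 2.6 defaults to
# weights_only=True, which cannot unpickle a module, so the flag is set explicitly (only do this for trusted files).
model.lm_head = torch.load('lm_head.pth', weights_only=False)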
tokenizer = AutoTokenizer.from_pretrained("yibinlei/LENS-d8000")
# Preprocess the data
query_max_len, doc_max_len = 512, 512
instructed_query_inputs = tokenizer(
instructed_queries,
padding=True,
truncation=True,
return_tensors='pt',
max_length=query_max_len,
add_special_tokens=True
)
doc_inputs = tokenizer(
docs,
padding=True,
truncation=True,
return_tensors='pt',
max_length=doc_max_len,
add_special_tokens=True
)
# We perform pooling exclusively on the outputs of the query tokens, excluding outputs from the instruction.
# The -2 slice relies on left padding, so every sequence ends at the same position.
query_only_mask = torch.zeros_like(instructed_query_inputs['input_ids'], dtype=instructed_query_inputs['attention_mask'].dtype)
special_token_id = tokenizer.convert_tokens_to_ids('<query>')
for idx, seq in enumerate(instructed_query_inputs['input_ids']):
    special_pos = (seq == special_token_id).nonzero()
    if len(special_pos) > 0:
        # Outputs are shifted left by one, so the mask starts at <query> and stops before the last two positions.
        query_start_pos = special_pos[-1].item()
        query_only_mask[idx, query_start_pos:-2] = 1
    else:
        raise ValueError(f"No <query> token found in sequence {idx}")
# Obtain the embeddings
# pooling_func expects the model output to be a tensor of per-token scores with shape (batch, seq_len, 8000).
with torch.no_grad():
    instructed_query_outputs = model(**instructed_query_inputs)
    query_embeddings = pooling_func(instructed_query_outputs, query_only_mask)
    doc_outputs = model(**doc_inputs)
    # As the output of each token is used to predict the next token, the pooling mask is shifted left by 1.
    # The output at the final EOS token is also excluded.
    doc_inputs['attention_mask'][:, -2:] = 0
    doc_embeddings = pooling_func(doc_outputs, doc_inputs['attention_mask'])
# Normalize the embeddings
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
# Compute the cosine similarity (the embeddings are L2-normalized); shape: (num_queries, num_docs)
similarity = torch.matmul(query_embeddings, doc_embeddings.T)
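The masking and pooling steps in the usage code can be checked without downloading the model. The sketch below reuses `pooling_func` on a hypothetical toy batch. The token IDs (including the stand-in IDs for BOS, EOS, `<instruct>`, and `<query>`), the 6-dimensional random "outputs", and the reuse of the query layout for the document mask are illustrative assumptions, not LENS values. Like the `-2` slices in the usage code, it assumes left padding and a trailing EOS token.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def pooling_func(vecs: Tensor, pooling_mask: Tensor) -> Tensor:
    # Same pooling as the usage code: log(1 + ReLU(x)), zero out masked positions, max over tokens.
    return torch.max(torch.log(1 + torch.relu(vecs)) * pooling_mask.unsqueeze(-1), dim=1).values


# Hypothetical token IDs for two left-padded, instructed queries of length 8:
# 0 = PAD, 1 = BOS, 2 = EOS, 7 = "<instruct>", 9 = "<query>", other IDs are ordinary tokens.
PAD, QUERY_TOKEN = 0, 9
input_ids = torch.tensor([
    [1, 7, 30, 31, 9, 40, 41, 2],  # BOS <instruct> i i <query> q q EOS
    [0, 1, 7, 30, 9, 50, 51, 2],   # one left pad, shorter instruction
])
attention_mask = (input_ids != PAD).long()

# Query-only mask, built as in the usage code: start at the last "<query>" position
# (outputs are shifted left by one) and drop the last two positions.
query_only_mask = torch.zeros_like(input_ids)
for idx, seq in enumerate(input_ids):
    start = (seq == QUERY_TOKEN).nonzero()[-1].item()
    query_only_mask[idx, start:-2] = 1
print(query_only_mask)
# tensor([[0, 0, 0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1, 0, 0]])

# Document mask: the attention mask with the last two positions zeroed.
doc_mask = attention_mask.clone()
doc_mask[:, -2:] = 0

# Random stand-ins for the model's per-token outputs: (batch, seq_len, dim).
torch.manual_seed(0)
dim = 6  # 8000 for LENS-d8000
query_logits = torch.randn(2, 8, dim)
doc_logits = torch.randn(2, 8, dim)

query_emb = F.normalize(pooling_func(query_logits, query_only_mask), p=2, dim=1)
doc_emb = F.normalize(pooling_func(doc_logits, doc_mask), p=2, dim=1)

# Outputs at masked positions (instruction, padding, EOS) do not affect the embedding.
perturbed = query_logits.clone()
perturbed[query_only_mask == 0] = 100.0
assert torch.equal(pooling_func(perturbed, query_only_mask),
                   pooling_func(query_logits, query_only_mask))

similarity = query_emb @ doc_emb.T   # (num_queries, num_docs) cosine scores
best_doc = similarity.argmax(dim=1)  # index of the top-ranked document for each query
print(similarity.shape)              # torch.Size([2, 2])
```

The in-block `assert` shows why the masks matter. Whatever the model emits at instruction, padding, or EOS positions is multiplied by zero before the max. Because log(1 + ReLU(x)) is never negative, those zeros cannot change the result.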
Evaluation results
- accuracy on MTEB AmazonCounterfactualClassification (en) test set (self-reported): 93.687
- ap on MTEB AmazonCounterfactualClassification (en) test set (self-reported): 74.448
- ap_weighted on MTEB AmazonCounterfactualClassification (en) test set (self-reported): 74.448
- f1 on MTEB AmazonCounterfactualClassification (en) test set (self-reported): 90.573
- f1_weighted on MTEB AmazonCounterfactualClassification (en) test set (self-reported): 93.872
- main_score on MTEB AmazonCounterfactualClassification (en) test set (self-reported): 93.687
- accuracy on MTEB AmazonPolarityClassification (default) test set (self-reported): 97.068
- ap on MTEB AmazonPolarityClassification (default) test set (self-reported): 95.710
- ap_weighted on MTEB AmazonPolarityClassification (default) test set (self-reported): 95.710
- f1 on MTEB AmazonPolarityClassification (default) test set (self-reported): 97.068