import json
import time
from abc import ABCMeta, abstractmethod

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


class BaseEmbedding(metaclass=ABCMeta):
    """
    Base Embedding interface.
    """

    @abstractmethod
    def to_embeddings(self, data, **kwargs):
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        return 0


class SBERT(BaseEmbedding):
    """Generate sentence embeddings for the given text using pretrained Sentence Transformers models.

    :param model: model name, defaults to 'all-MiniLM-L6-v2'.
    :type model: str

    Example:
        .. code-block:: python

            test_sentence = 'Hello, world.'
            encoder = SBERT('all-MiniLM-L6-v2')
            embed = encoder.to_embeddings(test_sentence)
    """

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model)
        self.model.eval()
        self.__dimension = None

    def to_embeddings(self, data, **_):
        """Generate embeddings for the given text input.

        :param data: a string, or a list of strings.
        :type data: str or list[str]

        :return: a float32 array of embeddings with shape (n, dim).
        """
        if not isinstance(data, list):
            data = [data]
        emb = self.model.encode(data)
        _, dim = emb.shape
        if not self.__dimension:
            self.__dimension = dim
        return np.array(emb).astype("float32")

    @property
    def dimension(self):
        """Embedding dimension.

        :return: embedding dimension
        """
        if not self.__dimension:
            # Encode a dummy sentence once to discover the output dimension.
            embd = self.model.encode(["foo"])
            _, self.__dimension = embd.shape
        return self.__dimension


def init_cache(embedding_model: str = "all-MiniLM-L6-v2"):
    """Initializes the cache with a Faiss index and an SBERT model.

    Args:
        embedding_model (str): The name of the SBERT model to use.

    Returns:
        tuple: (index, encoder) where
            - index is a Faiss index for storing embeddings.
            - encoder is an SBERT model instance.
    """
    encoder = SBERT(embedding_model)
    dimension = encoder.dimension
    print(f'Embedding dimension: {dimension}')
    index = faiss.IndexFlatL2(dimension)
    if index.is_trained:
        print('Index initialized and ready for use')

    return index, encoder
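

# Illustrative sketch (not part of the module's API): the index and encoder
# returned by init_cache can be exercised directly. The sentence below is
# made-up test data; with IndexFlatL2 the distance from a vector to itself
# should be ~0.0.
#
#   index, encoder = init_cache()
#   vec = encoder.to_embeddings("How do I reset my password?")   # shape (1, dim)
#   index.add(vec)
#   distances, ids = index.search(vec, 1)
#   print(distances[0][0], ids[0][0])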


def retrieve_cache(json_file):
    try:
        with open(json_file, 'r', encoding='utf-8') as file:
            cache = json.load(file)
    except FileNotFoundError:
        # Start from an empty cache if no file exists yet.
        cache = {'questions': [], 'answers': []}

    return cache


def store_cache(json_file, cache):
    with open(json_file, 'w', encoding='utf-8') as file:
        json.dump(cache, file)
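

# Quick round-trip sketch (assumes write access to the working directory;
# 'demo_cache.json' is just an illustrative file name):
#
#   cache = {'questions': ['hi'], 'answers': ['hello']}
#   store_cache('demo_cache.json', cache)
#   assert retrieve_cache('demo_cache.json') == cache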


class Cache:
    def __init__(self, embedding="all-MiniLM-L6-v2", json_file="cache_file.json",
                 threshold=0.5, max_response=100, eviction_policy='FIFO'):
        """Initializes the semantic cache.

        Args:
            embedding (str): The name of the SBERT model used to embed questions.
            json_file (str): The name of the JSON file where the cache is stored.
            threshold (float): The Euclidean distance threshold used to decide whether a cached question is similar.
            max_response (int): The maximum number of responses the cache can store.
            eviction_policy (str): The policy for evicting items from the cache.
                Only 'FIFO' (First In First Out) is implemented for now.
                If None, no eviction policy is applied.
        """
        self.index, self.encoder = init_cache(embedding)
        self.euclidean_threshold = threshold
        self.is_missed = True
        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)
        self.max_response = max_response
        self.eviction_policy = eviction_policy
        # Re-embed any questions loaded from disk so that Faiss row ids stay
        # aligned with the positions in the cached question/answer lists.
        if self.cache['questions']:
            self.index.add(self.encoder.to_embeddings(self.cache['questions']))

    def evict(self):
        """Evicts items from the cache based on the eviction policy.

        Note: only the question/answer lists are trimmed; the corresponding
        vectors remain in the Faiss index.
        """
        if self.eviction_policy and len(self.cache["questions"]) > self.max_response:
            for _ in range(len(self.cache["questions"]) - self.max_response):
                if self.eviction_policy == 'FIFO':
                    self.cache["questions"].pop(0)
                    self.cache["answers"].pop(0)

    def cached_hit(self, question: str):
        """Handles the cache lookup by searching for a semantically similar question.

        Args:
            question (str): The input question.

        Returns:
            str: The cached answer when a similar question is found.
            tuple: (embedding, True) on a miss, so the embedding can be reused
                by ``cache_miss``.
        """
        embedding = self.encoder.to_embeddings([question])
        # IndexFlatL2 performs an exact search, so no nprobe tuning is needed.
        D, I = self.index.search(embedding, 1)
        print(f'Nearest distance: {D[0][0]:.3f}')
        # The raw L2 distance is scaled down by 100 before the threshold check.
        if I[0][0] >= 0 and D[0][0] / 100 <= self.euclidean_threshold:
            row_id = int(I[0][0])
            print('Answer recovered from cache.')
            print(f'Distance: {D[0][0]:.3f} (threshold: {self.euclidean_threshold})')
            print(f'Found in cache at row: {row_id} with score: {D[0][0]:.3f}')
            self.is_missed = False
            return self.cache['answers'][row_id]
        self.is_missed = True
        return embedding, self.is_missed

    def cache_miss(self, question: str, embedding, answer) -> str:
        """Handles a cache miss by storing the model's answer and updating the cache.

        Args:
            question (str): The input question.
            embedding: The question embedding returned by ``cached_hit`` on a miss.
            answer (str): The answer produced by the LLM.

        Returns:
            str: The answer, after it has been appended to the cache.
        """
        self.cache['questions'].append(question)
        self.cache['answers'].append(answer)

        print('Answer not found in cache, appending new answer.')
        print(f'Response: {answer}')

        # Add the question embedding to the Faiss index, evict if the cache is
        # over capacity, and persist the cache to disk.
        self.index.add(embedding)
        self.evict()
        store_cache(self.json_file, self.cache)
        self.is_missed = False

        return answer
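

# Minimal end-to-end sketch of the hit/miss flow. `answer_from_llm` is a
# stand-in for whatever LLM call the application makes; it is not part of
# this module, and 'demo_cache.json' is just an illustrative file name.
if __name__ == "__main__":
    def answer_from_llm(prompt: str) -> str:
        return f"(placeholder answer for: {prompt})"

    semantic_cache = Cache(json_file="demo_cache.json", threshold=0.5)

    for q in ["What is semantic caching?", "Explain semantic caching briefly."]:
        start = time.time()
        result = semantic_cache.cached_hit(q)
        if semantic_cache.is_missed:
            embedding, _ = result
            result = semantic_cache.cache_miss(q, embedding, answer_from_llm(q))
        print(f'{q} -> {result} ({time.time() - start:.2f}s)')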