import os
import logging
from typing import Any, List

from pydantic import Extra
import pinecone
import google.generativeai as genai
from llama_index import (
    ServiceContext,
    PromptHelper,
    VectorStoreIndex,
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
    """LlamaIndex embedding adapter backed by the PaLM embedding model."""

    def __init__(
        self,
        model_name: str = 'models/embedding-gecko-001',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name

    @classmethod
    def class_name(cls) -> str:
        return 'PaLMEmbeddings'

    def gen_embeddings(self, text: str) -> dict:
        # Returns the raw PaLM response, a dict with an 'embedding' key.
        return genai.generate_embeddings(self._model_name, text)

    def _get_query_embedding(self, query: str) -> List[float]:
        embeddings = self.gen_embeddings(query)
        return embeddings['embedding']

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self.gen_embeddings(text)
        return embeddings['embedding']

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = [
            self.gen_embeddings(text)['embedding'] for text in texts
        ]
        return embeddings

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)
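With the adapter in place, the embedding model can be exercised on its own before it is wired into an index. A minimal sketch, assuming genai.configure() has already been called with a valid PALM_API_KEY; the sample text is only for illustration:

# Standalone check of the embedding adapter (illustrative only).
genai.configure(api_key=os.environ['PALM_API_KEY'])

emb_model = LlamaIndexPaLMEmbeddings()
vector = emb_model.get_text_embedding('Led the migration of a legacy ETL pipeline to BigQuery.')
print(len(vector))  # dimensionality of the embedding-gecko-001 vector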
class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
    """LlamaIndex CustomLLM wrapper around the PaLM text generation model."""

    def __init__(
        self,
        model_name: str = 'models/text-bison-001',
        context_window: int = 8196,
        num_output: int = 1024,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name
        self._context_window = context_window
        self._num_output = num_output

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name,
        )

    def gen_texts(self, prompt: str) -> str:
        logging.debug(f"prompt: {prompt}")
        response = genai.generate_text(
            model=self._model_name,
            prompt=prompt,
            safety_settings=[
                {
                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                },
            ],
        )
        logging.debug(f"response:\n{response}")
        return response.candidates[0]['output']

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        text = self.gen_texts(prompt)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
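The text wrapper can likewise be smoke-tested directly, outside of any index. A rough sketch, again assuming genai.configure() has already run; the prompt is just an example:

# Direct completion through the CustomLLM wrapper (illustrative only).
llm = LlamaIndexPaLMText()
print(llm.metadata)  # context_window / num_output / model_name
completion = llm.complete('Summarize what a vector database is in one sentence.')
print(completion.text)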
class LlamaIndexPaLM():
    """Wires the PaLM embedding and text models into a Pinecone-backed LlamaIndex query pipeline."""

    def __init__(
        self,
        emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
        model: LlamaIndexPaLMText = LlamaIndexPaLMText(),
    ) -> None:
        self.emb_model = emb_model
        self.llm = model

        # Google Generative AI
        genai.configure(api_key=os.environ['PALM_API_KEY'])

        # Pinecone
        pinecone.init(
            api_key=os.environ['PINECONE_API_KEY'],
            environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free')
        )

        # model metadata (environment variables arrive as strings, so cast them)
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
        if TEXT_CHUNK_SIZE_LIMIT is not None:
            TEXT_CHUNK_SIZE_LIMIT = int(TEXT_CHUNK_SIZE_LIMIT)

        self.node_parser = SimpleNodeParser.from_defaults(
            text_splitter=TokenTextSplitter(
                chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
            )
        )
        self.prompt_helper = PromptHelper(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT,
        )
        self.service_context = ServiceContext.from_defaults(
            llm=self.llm,
            embed_model=self.emb_model,
            node_parser=self.node_parser,
            prompt_helper=self.prompt_helper,
        )

    def set_index_from_pinecone(
        self,
        index_name: str = 'experience'
    ) -> None:
        # Pinecone VectorStore
        pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(
            pinecone_index=pinecone_index, add_sparse_vector=True
        )
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            self.vector_store, service_context=self.service_context
        )
        return None

    def generate_response(
        self,
        query: str
    ) -> str:
        response = self.pinecone_index.as_query_engine().query(query)
        return response.response
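Putting it together, a typical request path is: export the two API keys, build the LlamaIndexPaLM wrapper (which configures the clients and the service context), point it at an existing Pinecone index, and query it. A minimal end-to-end sketch, assuming the Pinecone index named 'experience' already holds embedded documents; the question string is only an example:

# End-to-end usage of the wrapper (illustrative only).
# Requires PALM_API_KEY and PINECONE_API_KEY in the environment.
palm = LlamaIndexPaLM()
palm.set_index_from_pinecone('experience')  # attach the pre-populated Pinecone index

answer = palm.generate_response('What cloud platforms have you worked with?')
print(answer)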