# g-palm-chat/src/llamaindex_palm.py
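"""LlamaIndex integration for Google PaLM and Pinecone.

Provides a PaLM embedding model and a PaLM text-completion LLM for LlamaIndex,
plus a wrapper class that combines them into a ServiceContext and exposes a
Pinecone-backed vector index for retrieval-augmented question answering.
"""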
import os
import logging
from typing import Any, Dict, List
from pydantic import Extra
import pinecone
import google.generativeai as genai
from llama_index import (
ServiceContext,
PromptHelper,
VectorStoreIndex
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
CustomLLM,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
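    """LlamaIndex embedding model that calls the PaLM embedding endpoint via google.generativeai."""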
def __init__(
self,
model_name: str = 'models/embedding-gecko-001',
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._model_name = model_name
@classmethod
def class_name(cls) -> str:
return 'PaLMEmbeddings'
    def gen_embeddings(self, text: str) -> Dict[str, List[float]]:
        # PaLM returns a dict with an 'embedding' key holding the vector
        return genai.generate_embeddings(self._model_name, text)
def _get_query_embedding(self, query: str) -> List[float]:
embeddings = self.gen_embeddings(query)
return embeddings['embedding']
def _get_text_embedding(self, text: str) -> List[float]:
embeddings = self.gen_embeddings(text)
return embeddings['embedding']
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
embeddings = [
self.gen_embeddings(text)['embedding'] for text in texts
]
return embeddings
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
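    """LlamaIndex CustomLLM that calls the PaLM text-completion endpoint via google.generativeai."""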
def __init__(
self,
model_name: str = 'models/text-bison-001',
context_window: int = 8196,
num_output: int = 1024,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._model_name = model_name
self._context_window = context_window
self._num_output = num_output
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self._context_window,
num_output=self._num_output,
model_name=self._model_name
)
    def gen_texts(self, prompt: str) -> str:
        logging.debug(f"prompt: {prompt}")
        response = genai.generate_text(
            model=self._model_name,
            prompt=prompt,
            # threshold BLOCK_NONE so candidates are not filtered by safety settings
            safety_settings=[
                {
                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                },
            ]
        )
        logging.debug(f"response:\n{response}")
        # candidates are dicts; 'output' holds the generated text
        return response.candidates[0]['output']
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
text = self.gen_texts(prompt)
return CompletionResponse(text=text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
raise NotImplementedError()
class LlamaIndexPaLM:
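    """Wires the PaLM LLM and embedding model into a ServiceContext backed by a Pinecone vector store."""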
def __init__(
self,
emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
model: LlamaIndexPaLMText = LlamaIndexPaLMText()
) -> None:
self.emb_model = emb_model
self.llm = model
# Google Generative AI
genai.configure(api_key=os.environ['PALM_API_KEY'])
# Pinecone
pinecone.init(
api_key=os.environ['PINECONE_API_KEY'],
environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free')
)
# model metadata
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        # env vars arrive as strings, so cast them before passing to PromptHelper
        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
        if TEXT_CHUNK_SIZE_LIMIT is not None:
            TEXT_CHUNK_SIZE_LIMIT = int(TEXT_CHUNK_SIZE_LIMIT)
self.node_parser = SimpleNodeParser.from_defaults(
text_splitter=TokenTextSplitter(
chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
)
)
self.prompt_helper = PromptHelper(
context_window=CONTEXT_WINDOW,
num_output=NUM_OUTPUT,
chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
)
self.service_context = ServiceContext.from_defaults(
llm=self.llm,
embed_model=self.emb_model,
node_parser=self.node_parser,
prompt_helper=self.prompt_helper,
)
def set_index_from_pinecone(
self,
index_name: str = 'experience'
) -> None:
        # Pinecone VectorStore wrapping an existing index
        pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(pinecone_index=pinecone_index, add_sparse_vector=True)
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            self.vector_store, service_context=self.service_context
        )
def generate_response(
self,
query: str
) -> str:
        # query the Pinecone-backed index and return the synthesized answer text
        response = self.pinecone_index.as_query_engine().query(query)
return response.response
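
# Example usage: a minimal sketch assuming PALM_API_KEY and PINECONE_API_KEY are
# set in the environment and that a Pinecone index named 'experience' already
# contains embedded documents (the query string below is illustrative).
if __name__ == '__main__':
    palm = LlamaIndexPaLM()
    palm.set_index_from_pinecone(index_name='experience')
    print(palm.generate_response('Summarize the indexed experience.'))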