"""Custom LlamaIndex components backed by Google PaLM and Pinecone.

Defines a PaLM embedding model, a PaLM text-generation LLM, and a wrapper
class that wires them into a ServiceContext and queries an existing
Pinecone vector index.
"""
import os
import logging

from typing import Any, List, Optional
from pydantic import Extra

import pinecone
import google.generativeai as genai

from llama_index import (
    ServiceContext,
    PromptHelper,
    VectorStoreIndex
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback

class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
    """LlamaIndex embedding model backed by the Google PaLM embeddings API."""
    def __init__(
        self,
        model_name: str = 'models/embedding-gecko-001',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name

    @classmethod
    def class_name(cls) -> str:
        return 'PaLMEmbeddings'

    def gen_embeddings(self, text: str) -> Any:
        # genai.generate_embeddings returns a dict with an 'embedding' key,
        # which the callers below index into
        return genai.generate_embeddings(self._model_name, text)

    def _get_query_embedding(self, query: str) -> List[float]:
        embeddings = self.gen_embeddings(query)
        return embeddings['embedding']

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self.gen_embeddings(text)
        return embeddings['embedding']

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = [
            self.gen_embeddings(text)['embedding'] for text in texts
        ]
        return embeddings
    
    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)
    
class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
    """LlamaIndex CustomLLM wrapper around the Google PaLM text-generation API."""
    def __init__(
        self,
        model_name: str = 'models/text-bison-001',
        context_window: int = 8196,
        num_output: int = 1024,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name
        self._context_window = context_window
        self._num_output = num_output
        
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name
        )

    def gen_texts(self, prompt):
        """Generate text with the PaLM API and return the top candidate's output."""
        logging.debug(f"prompt: {prompt}")
        response = genai.generate_text(
            model=self._model_name,
            prompt=prompt,
            safety_settings=[
                {
                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                },
            ]
        )
        logging.debug(f"response:\n{response}")
        if not response.candidates:
            # e.g. the prompt was blocked by a safety filter
            raise ValueError('PaLM returned no candidates for this prompt')
        return response.candidates[0]['output']

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        text = self.gen_texts(prompt)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
    
class LlamaIndexPaLM:
    """Wires the PaLM LLM and embeddings into LlamaIndex and queries a Pinecone index."""

    def __init__(
        self,
        emb_model: Optional[LlamaIndexPaLMEmbeddings] = None,
        model: Optional[LlamaIndexPaLMText] = None
    ) -> None:
        # construct defaults here rather than in the signature, so the models are
        # not instantiated at import time and not shared across instances
        self.emb_model = emb_model if emb_model is not None else LlamaIndexPaLMEmbeddings()
        self.llm = model if model is not None else LlamaIndexPaLMText()
        
        # Google Generative AI
        genai.configure(api_key=os.environ['PALM_API_KEY'])

        # Pinecone
        pinecone.init(
            api_key=os.environ['PINECONE_API_KEY'],
            environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free')
        )

        # model metadata (os.getenv returns strings, so cast numeric values)
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT')
        if TEXT_CHUNK_SIZE_LIMIT is not None:
            TEXT_CHUNK_SIZE_LIMIT = int(TEXT_CHUNK_SIZE_LIMIT)

        self.node_parser = SimpleNodeParser.from_defaults(
            text_splitter=TokenTextSplitter(
                chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
            )
        )

        self.prompt_helper = PromptHelper(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
        )

        self.service_context = ServiceContext.from_defaults(
            llm=self.llm,
            embed_model=self.emb_model,
            node_parser=self.node_parser,
            prompt_helper=self.prompt_helper,
        )
    
    def set_index_from_pinecone(
        self, 
        index_name: str = 'experience'
    ) -> None:
        # wrap an existing Pinecone index in a LlamaIndex VectorStoreIndex
        pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(
            pinecone_index=pinecone_index, add_sparse_vector=True
        )
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            self.vector_store, service_context=self.service_context
        )
        return None
    
    def generate_response(
        self,
        query: str
    ) -> str:
        response = self.pinecone_index.as_query_engine().query(query)
        return response.response
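

# Illustrative usage only: a minimal sketch, not part of the module API.
# It assumes PALM_API_KEY and PINECONE_API_KEY are set in the environment and
# that a populated Pinecone index named 'experience' (the default above)
# already exists; the query string is a made-up example.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    palm = LlamaIndexPaLM()
    palm.set_index_from_pinecone('experience')
    print(palm.generate_response('What does this knowledge base cover?'))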