Restodecoca commited on
Commit
055befa
·
verified ·
1 Parent(s): dc34793

Update app.py

Browse files

adicionado bm25s revisado junto do bm25 retriever para melhor funcionamento

Files changed (1) hide show
  1. app.py +241 -53
app.py CHANGED
@@ -20,7 +20,7 @@ from llama_index.core.storage.chat_store import SimpleChatStore
20
  from llama_index.core.memory import ChatMemoryBuffer
21
  from llama_index.core.query_engine import RetrieverQueryEngine
22
  from llama_index.core.chat_engine import CondensePlusContextChatEngine
23
- from llama_index.retrievers.bm25 import BM25Retriever
24
  from llama_index.core.retrievers import QueryFusionRetriever
25
  from llama_index.vector_stores.chroma import ChromaVectorStore
26
  from llama_index.core import VectorStoreIndex
@@ -29,6 +29,238 @@ from llama_index.core import VectorStoreIndex
29
  # from llama_index.embeddings.huggingface import HuggingFaceEmbedding
30
  import chromadb
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  #Configuração da imagem da aba
33
  im = Image.open("pngegg.png")
34
  st.set_page_config(page_title = "Chatbot Carômetro", page_icon=im, layout = "wide")
@@ -38,8 +270,6 @@ os.makedirs("bm25_retriever", exist_ok=True)
38
  os.makedirs("chat_store", exist_ok=True)
39
  os.makedirs("chroma_db", exist_ok=True)
40
  os.makedirs("documentos", exist_ok=True)
41
- os.makedirs("curadoria", exist_ok=True)
42
- os.makedirs("chroma_db_curadoria", exist_ok=True)
43
 
44
  # Configuração do Streamlit
45
  st.sidebar.title("Configuração de LLM")
@@ -120,9 +350,7 @@ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
120
  chat_store_path = os.path.join("chat_store", "chat_store.json")
121
  documents_path = os.path.join("documentos")
122
  chroma_storage_path = os.path.join("chroma_db") # Diretório para persistência do Chroma
123
- chroma_storage_path_curadoria = os.path.join("chroma_db_curadoria") # Diretório para 'curadoria'
124
  bm25_persist_path = os.path.join("bm25_retriever")
125
- curadoria_path = os.path.join("curadoria")
126
 
127
  # Classe CSV Customizada (novo código)
128
  class CustomPandasCSVReader:
@@ -192,7 +420,7 @@ with open(credentials_path, 'w') as credentials_file:
192
 
193
  with open(token_path, 'w') as credentials_file:
194
  credentials_file.write(token_json)
195
-
196
  google_drive_reader = GoogleDriveReader(credentials_path=credentials_path)
197
  google_drive_reader._creds = google_drive_reader._get_credentials()
198
 
@@ -222,8 +450,6 @@ def download_original_files_from_folder(greader: GoogleDriveReader, pasta_docume
222
 
223
  #DADOS/QA_database/Documentos CSV/documentos
224
  pasta_documentos_drive = "1xVzo8s1D0blzR5ZB3m5k4dVWHuRmKUu-"
225
- #DADOS/QA_database/Documentos CSV/curadoria
226
- pasta_curadoria_drive = "1LRrdOkZy9p0FA3MQAyz-Ssj3ktKTWAwE"
227
 
228
  # Verifica e baixa arquivos se necessário (novo código)
229
  if not are_docs_downloaded(documents_path):
@@ -232,18 +458,14 @@ if not are_docs_downloaded(documents_path):
232
  else:
233
  logging.info("'documentos' já contém arquivos, ignorando download.")
234
 
235
- if not are_docs_downloaded(curadoria_path):
236
- logging.info("Baixando arquivos originais do Drive para 'curadoria'...")
237
- download_original_files_from_folder(google_drive_reader, pasta_curadoria_drive, curadoria_path)
238
- else:
239
- logging.info("'curadoria' já contém arquivos, ignorando download.")
240
-
241
  # Configuração de leitura de documentos
242
  file_extractor = {".csv": CustomPandasCSVReader()}
243
  documents = SimpleDirectoryReader(
244
  input_dir=documents_path,
245
  file_extractor=file_extractor,
246
- filename_as_id=True
 
 
247
  ).load_data()
248
 
249
  documents = clean_documents(documents)
@@ -266,7 +488,7 @@ if os.path.exists(chroma_storage_path):
266
  index = VectorStoreIndex.from_vector_store(vector_store)
267
  else:
268
  splitter = LangchainNodeParser(
269
- RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
270
  )
271
  index = VectorStoreIndex.from_documents(
272
  documents,
@@ -287,45 +509,11 @@ else:
287
  os.makedirs(bm25_persist_path, exist_ok=True)
288
  bm25_retriever.persist(bm25_persist_path)
289
 
290
- #Adicionado documentos na pasta curadoria, foi setado para 1200 o chunk pra receber pergunta, contexto e resposta
291
- curadoria_documents = SimpleDirectoryReader(
292
- input_dir=curadoria_path,
293
- file_extractor=file_extractor,
294
- filename_as_id=True
295
- ).load_data()
296
-
297
- curadoria_documents = clean_documents(curadoria_documents)
298
- curadoria_docstore = SimpleDocumentStore()
299
- curadoria_docstore.add_documents(curadoria_documents)
300
-
301
- db_curadoria = chromadb.PersistentClient(path=chroma_storage_path_curadoria)
302
- chroma_collection_curadoria = db_curadoria.get_or_create_collection("dense_vectors_curadoria")
303
- vector_store_curadoria = ChromaVectorStore(chroma_collection=chroma_collection_curadoria)
304
-
305
- # Configuração do StorageContext para 'curadoria'
306
- storage_context_curadoria = StorageContext.from_defaults(
307
- docstore=curadoria_docstore, vector_store=vector_store_curadoria
308
- )
309
-
310
- # Criação/Recarregamento do índice com embeddings para 'curadoria'
311
- if os.path.exists(chroma_storage_path_curadoria):
312
- curadoria_index = VectorStoreIndex.from_vector_store(vector_store_curadoria)
313
- else:
314
- curadoria_splitter = LangchainNodeParser(
315
- RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=100)
316
- )
317
- curadoria_index = VectorStoreIndex.from_documents(
318
- curadoria_documents, storage_context=storage_context_curadoria, transformations=[curadoria_splitter]
319
- )
320
- vector_store_curadoria.persist()
321
-
322
- curadoria_retriever = curadoria_index.as_retriever(similarity_top_k=2)
323
-
324
  # Combinação de Retrievers (Embeddings + BM25)
325
  vector_retriever = index.as_retriever(similarity_top_k=2)
326
  retriever = QueryFusionRetriever(
327
- [vector_retriever, bm25_retriever, curadoria_retriever],
328
- similarity_top_k=2,
329
  num_queries=0,
330
  mode="reciprocal_rerank",
331
  use_async=True,
@@ -397,4 +585,4 @@ if user_input:
397
 
398
  # Remover o cursor após a conclusão
399
  message_placeholder.markdown(assistant_message)
400
- st.session_state.chat_history.append(f"assistant: {assistant_message}")
 
20
  from llama_index.core.memory import ChatMemoryBuffer
21
  from llama_index.core.query_engine import RetrieverQueryEngine
22
  from llama_index.core.chat_engine import CondensePlusContextChatEngine
23
+ #from llama_index.retrievers.bm25 import BM25Retriever
24
  from llama_index.core.retrievers import QueryFusionRetriever
25
  from llama_index.vector_stores.chroma import ChromaVectorStore
26
  from llama_index.core import VectorStoreIndex
 
29
  # from llama_index.embeddings.huggingface import HuggingFaceEmbedding
30
  import chromadb
31
 
32
+ ###############################################################################
33
+ # MONKEY PATCH EM bm25s #
34
+ ###############################################################################
35
+ import bm25s
36
+
37
+ # Guardamos a referência da função original
38
+ orig_find_newline_positions = bm25s.utils.corpus.find_newline_positions
39
+
40
+ def patched_find_newline_positions(path, show_progress=True, leave_progress=True):
41
+ """
42
+ Versão 'gambiarra' da função original, forçando uso de encoding='utf-8'
43
+ e ignorando erros de decodificação. Assim, evitamos UnicodeDecodeError
44
+ mesmo que o arquivo contenha caracteres fora da faixa UTF-8.
45
+
46
+ (Esta referência é real, baseada em ajustes de leitura de arquivos do Python.)
47
+ """
48
+ path = str(path)
49
+ indexes = []
50
+
51
+ with open(path, "r", encoding="utf-8", errors="ignore") as f:
52
+ indexes.append(f.tell())
53
+ file_size = os.path.getsize(path)
54
+
55
+ try:
56
+ from tqdm.auto import tqdm
57
+ pbar = tqdm(
58
+ total=file_size,
59
+ desc="Finding newlines for mmindex",
60
+ unit="B",
61
+ unit_scale=True,
62
+ leave=leave_progress,
63
+ disable=not show_progress,
64
+ )
65
+ except ImportError:
66
+ pbar = None
67
+
68
+ while True:
69
+ line = f.readline()
70
+ if not line:
71
+ break
72
+ t = f.tell()
73
+ indexes.append(t)
74
+ if pbar is not None:
75
+ pbar.update(t - indexes[-2])
76
+
77
+ if pbar is not None:
78
+ pbar.close()
79
+
80
+ return indexes[:-1]
81
+
82
+ # Aplicamos nosso patch
83
+ bm25s.utils.corpus.find_newline_positions = patched_find_newline_positions
84
+ ###############################################################################
85
+ # CLASSE BM25Retriever (AJUSTADA PARA ENCODING) #
86
+ ###############################################################################
87
+ import json
88
+ import Stemmer
89
+
90
+ from llama_index.core.base.base_retriever import BaseRetriever
91
+ from llama_index.core.callbacks.base import CallbackManager
92
+ from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
93
+ from llama_index.core.schema import (
94
+ BaseNode,
95
+ IndexNode,
96
+ NodeWithScore,
97
+ QueryBundle,
98
+ MetadataMode,
99
+ )
100
+ from llama_index.core.vector_stores.utils import (
101
+ node_to_metadata_dict,
102
+ metadata_dict_to_node,
103
+ )
104
+ from typing import cast
105
+
106
+ logger = logging.getLogger(__name__)
107
+
108
+ DEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
109
+ DEFAULT_PERSIST_FILENAME = "retriever.json"
110
+
111
+
112
+ class BM25Retriever(BaseRetriever):
113
+ """
114
+ Implementação customizada do algoritmo BM25 com a lib bm25s, incluindo um
115
+ 'monkey patch' para contornar problemas de decodificação de caracteres.
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ nodes: Optional[List[BaseNode]] = None,
121
+ stemmer: Optional[Stemmer.Stemmer] = None,
122
+ language: str = "en",
123
+ existing_bm25: Optional[bm25s.BM25] = None,
124
+ similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
125
+ callback_manager: Optional[CallbackManager] = None,
126
+ objects: Optional[List[IndexNode]] = None,
127
+ object_map: Optional[dict] = None,
128
+ verbose: bool = False,
129
+ ) -> None:
130
+ self.stemmer = stemmer or Stemmer.Stemmer("english")
131
+ self.similarity_top_k = similarity_top_k
132
+
133
+ if existing_bm25 is not None:
134
+ # Usa instância BM25 existente
135
+ self.bm25 = existing_bm25
136
+ self.corpus = existing_bm25.corpus
137
+ else:
138
+ # Cria uma nova instância BM25 a partir de 'nodes'
139
+ if nodes is None:
140
+ raise ValueError("É preciso fornecer 'nodes' ou um 'existing_bm25'.")
141
+
142
+ self.corpus = [node_to_metadata_dict(node) for node in nodes]
143
+ corpus_tokens = bm25s.tokenize(
144
+ [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
145
+ stopwords=language,
146
+ stemmer=self.stemmer,
147
+ show_progress=verbose,
148
+ )
149
+ self.bm25 = bm25s.BM25()
150
+ self.bm25.index(corpus_tokens, show_progress=verbose)
151
+
152
+ super().__init__(
153
+ callback_manager=callback_manager,
154
+ object_map=object_map,
155
+ objects=objects,
156
+ verbose=verbose,
157
+ )
158
+
159
+ @classmethod
160
+ def from_defaults(
161
+ cls,
162
+ index: Optional[VectorStoreIndex] = None,
163
+ nodes: Optional[List[BaseNode]] = None,
164
+ docstore: Optional["BaseDocumentStore"] = None,
165
+ stemmer: Optional[Stemmer.Stemmer] = None,
166
+ language: str = "en",
167
+ similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
168
+ verbose: bool = False,
169
+ tokenizer: Optional[Any] = None,
170
+ ) -> "BM25Retriever":
171
+ if tokenizer is not None:
172
+ logger.warning(
173
+ "O parâmetro 'tokenizer' foi descontinuado e será removido "
174
+ "no futuro. Use um Stemmer do PyStemmer para melhor controle."
175
+ )
176
+
177
+ if sum(bool(val) for val in [index, nodes, docstore]) != 1:
178
+ raise ValueError("Passe exatamente um entre 'index', 'nodes' ou 'docstore'.")
179
+
180
+ if index is not None:
181
+ docstore = index.docstore
182
+
183
+ if docstore is not None:
184
+ nodes = cast(List[BaseNode], list(docstore.docs.values()))
185
+
186
+ assert nodes is not None, (
187
+ "Não foi possível determinar os nodes. Verifique seus parâmetros."
188
+ )
189
+
190
+ return cls(
191
+ nodes=nodes,
192
+ stemmer=stemmer,
193
+ language=language,
194
+ similarity_top_k=similarity_top_k,
195
+ verbose=verbose,
196
+ )
197
+
198
+ def get_persist_args(self) -> Dict[str, Any]:
199
+ """Dicionário com os parâmetros de persistência a serem salvos."""
200
+ return {
201
+ DEFAULT_PERSIST_ARGS[key]: getattr(self, key)
202
+ for key in DEFAULT_PERSIST_ARGS
203
+ if hasattr(self, key)
204
+ }
205
+
206
+ def persist(self, path: str, **kwargs: Any) -> None:
207
+ """
208
+ Persiste o retriever em um diretório, incluindo
209
+ a estrutura do BM25 e o corpus em JSON.
210
+ """
211
+ self.bm25.save(path, corpus=self.corpus, **kwargs)
212
+ with open(
213
+ os.path.join(path, DEFAULT_PERSIST_FILENAME),
214
+ "wt",
215
+ encoding="utf-8",
216
+ errors="ignore",
217
+ ) as f:
218
+ json.dump(self.get_persist_args(), f, indent=2, ensure_ascii=False)
219
+
220
+ @classmethod
221
+ def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
222
+ """
223
+ Carrega o retriever de um diretório, incluindo o BM25 e o corpus.
224
+ Devido ao nosso patch, ignoramos qualquer erro de decodificação
225
+ que eventualmente apareça.
226
+ """
227
+ bm25_obj = bm25s.BM25.load(path, load_corpus=True, **kwargs)
228
+ with open(
229
+ os.path.join(path, DEFAULT_PERSIST_FILENAME),
230
+ "rt",
231
+ encoding="utf-8",
232
+ errors="ignore",
233
+ ) as f:
234
+ retriever_data = json.load(f)
235
+
236
+ return cls(existing_bm25=bm25_obj, **retriever_data)
237
+
238
+ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
239
+ """Recupera nós relevantes a partir do BM25."""
240
+ query = query_bundle.query_str
241
+ tokenized_query = bm25s.tokenize(
242
+ query, stemmer=self.stemmer, show_progress=self._verbose
243
+ )
244
+ indexes, scores = self.bm25.retrieve(
245
+ tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
246
+ )
247
+
248
+ # bm25s retorna lista de listas, pois suporta batched queries
249
+ indexes = indexes[0]
250
+ scores = scores[0]
251
+
252
+ nodes: List[NodeWithScore] = []
253
+ for idx, score in zip(indexes, scores):
254
+ if isinstance(idx, dict):
255
+ node = metadata_dict_to_node(idx)
256
+ else:
257
+ node_dict = self.corpus[int(idx)]
258
+ node = metadata_dict_to_node(node_dict)
259
+
260
+ nodes.append(NodeWithScore(node=node, score=float(score)))
261
+
262
+ return nodes
263
+
264
  #Configuração da imagem da aba
265
  im = Image.open("pngegg.png")
266
  st.set_page_config(page_title = "Chatbot Carômetro", page_icon=im, layout = "wide")
 
270
  os.makedirs("chat_store", exist_ok=True)
271
  os.makedirs("chroma_db", exist_ok=True)
272
  os.makedirs("documentos", exist_ok=True)
 
 
273
 
274
  # Configuração do Streamlit
275
  st.sidebar.title("Configuração de LLM")
 
350
  chat_store_path = os.path.join("chat_store", "chat_store.json")
351
  documents_path = os.path.join("documentos")
352
  chroma_storage_path = os.path.join("chroma_db") # Diretório para persistência do Chroma
 
353
  bm25_persist_path = os.path.join("bm25_retriever")
 
354
 
355
  # Classe CSV Customizada (novo código)
356
  class CustomPandasCSVReader:
 
420
 
421
  with open(token_path, 'w') as credentials_file:
422
  credentials_file.write(token_json)
423
+
424
  google_drive_reader = GoogleDriveReader(credentials_path=credentials_path)
425
  google_drive_reader._creds = google_drive_reader._get_credentials()
426
 
 
450
 
451
  #DADOS/QA_database/Documentos CSV/documentos
452
  pasta_documentos_drive = "1xVzo8s1D0blzR5ZB3m5k4dVWHuRmKUu-"
 
 
453
 
454
  # Verifica e baixa arquivos se necessário (novo código)
455
  if not are_docs_downloaded(documents_path):
 
458
  else:
459
  logging.info("'documentos' já contém arquivos, ignorando download.")
460
 
 
 
 
 
 
 
461
  # Configuração de leitura de documentos
462
  file_extractor = {".csv": CustomPandasCSVReader()}
463
  documents = SimpleDirectoryReader(
464
  input_dir=documents_path,
465
  file_extractor=file_extractor,
466
+ filename_as_id=True,
467
+ recursive=True
468
+ #Recursive caso tenha varias pastas no drive
469
  ).load_data()
470
 
471
  documents = clean_documents(documents)
 
488
  index = VectorStoreIndex.from_vector_store(vector_store)
489
  else:
490
  splitter = LangchainNodeParser(
491
+ RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
492
  )
493
  index = VectorStoreIndex.from_documents(
494
  documents,
 
509
  os.makedirs(bm25_persist_path, exist_ok=True)
510
  bm25_retriever.persist(bm25_persist_path)
511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  # Combinação de Retrievers (Embeddings + BM25)
513
  vector_retriever = index.as_retriever(similarity_top_k=2)
514
  retriever = QueryFusionRetriever(
515
+ [vector_retriever, bm25_retriever],
516
+ similarity_top_k=3,
517
  num_queries=0,
518
  mode="reciprocal_rerank",
519
  use_async=True,
 
585
 
586
  # Remover o cursor após a conclusão
587
  message_placeholder.markdown(assistant_message)
588
+ st.session_state.chat_history.append(f"assistant: {assistant_message}")